mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
140 lines
5.4 KiB
Python
140 lines
5.4 KiB
Python
from random import shuffle
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
|
|
from ...environment import constants as c
|
|
from ...environment.actions import Action, ALL_BASEACTIONS
|
|
from ...environment.entity.entity import Entity
|
|
from ..doors import constants as do
|
|
from ..maintenance import constants as mi
|
|
from ...utils import helpers as h
|
|
from ...utils.utility_classes import RenderEntity, Floor
|
|
from ..doors import DoorUse
|
|
|
|
|
|
class Maintainer(Entity):
    """Scripted entity that walks from objective to objective and services them."""

    def __init__(self, objective, action, *args, **kwargs):
        """
        Represents the maintainer entity that aims to maintain machines. The maintainer calculates its route using nx
        shortest path and restores the health of machines it visits to 100.

        :param objective: The maintainer's objective, e.g., "Machines".
        :type objective: str
        :param action: The default action to be performed by the maintainer.
        :type action: Action
        """
        super().__init__(*args, **kwargs)
        # Action executed when the maintainer stands on an objective (see ``tick``).
        self.action = action
        # All actions the maintainer may choose from when moving around.
        self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
        # Key used to look up the objective collection in the game state.
        self.objective = objective
        # Remaining positions of the current route (the own position is excluded).
        self._path = None
        # Targets that still have to be visited in the current round.
        self._next = []
        # Targets that were already selected in the current round.
        self._last = []
        # Name of the objective that was serviced most recently.
        self._last_serviced = 'None'

    def tick(self, state):
        """
        If there is an objective at the current position, the maintainer performs its action on the objective.
        If the objective has changed since the last servicing, the maintainer performs the action and updates
        the last serviced objective. Otherwise, it calculates a move action and performs it.

        :param state: The current game state.
        :type state: GameState
        :return: The result of the action performed by the maintainer.
        :rtype: ActionResult
        """
        found_objective = h.get_first(state[self.objective].by_pos(self.pos))
        if found_objective and found_objective.name != self._last_serviced:
            result = self.action.do(self, state)
            self._last_serviced = found_objective.name
        else:
            # Either nothing to service here, or this objective was just serviced: keep moving.
            result = self.get_move_action(state).do(self, state)
        self.set_state(result)
        return result

    def set_state(self, action_result):
        """
        Updates the maintainers own status with an action result.

        :param action_result: The result of the most recently performed action.
        """
        self._status = action_result

    def get_move_action(self, state) -> Action:
        """
        Retrieves the next move action for the agent.

        If a path is not already determined, the agent calculates the shortest path to its objective, considering doors
        and obstacles. If a closed door is found in the calculated path, the agent attempts to open it.

        :param state: The current state of the environment.
        :type state: GameState

        :return: The chosen move action for the agent.
        :rtype: Action
        :raises EnvironmentError: If the chosen action name matches none of the maintainer's actions.
        """
        if not self._path:
            self._plan_next_route(state)

        if door := self._closed_door_in_path(state):
            state.print(f"{self} found {door} that is closed. Attempt to open.")
            action = do.ACTION_DOOR_USE
        else:
            action = self._predict_move(state)

        # Translate the action name to the action object to have the same output as any other model.
        try:
            action_obj = h.get_first(self.actions, lambda x: x.name == action)
        except StopIteration as err:
            raise EnvironmentError(f"{self} has no action named {action!r}.") from err
        if action_obj is None:
            # ``h.get_first`` is used elsewhere in this class as returning a falsy value when
            # nothing matches; fail loudly here instead of handing ``None`` to the caller.
            raise EnvironmentError(f"{self} has no action named {action!r}.")
        return action_obj

    def _plan_next_route(self, state):
        """
        Internal Use

        Selects the next target and stores the route to it in ``self._path``. Targets that yield an empty route
        are skipped. When the target list is exhausted it is refilled at most once per call, so the search is
        bounded; if no target yields a route, ``self._path`` is left empty.

        :param state: The current state of the environment.
        :type state: GameState
        """
        refilled = False
        first_attempt = True
        while True:
            if not self._next:
                if refilled:
                    # Every known target was tried without success; give up for this tick.
                    self._path = []
                    return
                # Start a new round: all objectives plus one random free floor tile, in random order.
                self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
                shuffle(self._next)
                self._last = []
                refilled = True
            self._last.append(self._next.pop())
            if first_attempt:
                state.print("Calculating shortest path....")
                first_attempt = False
            else:
                state.print("Calculating shortest path.... Again....")
            self._path = self.calculate_route(self._last[-1], state.floortile_graph)
            if self._path:
                return

    def calculate_route(self, entity, floortile_graph) -> list:
        """
        Calculates the shortest route from the own position to the position of the given entity.

        :param entity: The target; only its ``pos`` attribute is read.
        :param floortile_graph: The graph handed to ``nx.shortest_path``.
        :returns: The path without the own (source) position but including the target position. The list is
                  empty if the target is the own position or if networkx reports that no route exists.
        :rtype: list
        """
        try:
            route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # The caller treats an empty route as "try another target".
            return []
        # Drop the first node, which is the position the maintainer is already standing on.
        return route[1:]

    def _closed_door_in_path(self, state):
        """
        Internal Use

        :param state: The current state of the environment.
        :return: The first closed door on the next path position, or None if there is none or the path is empty.
        """
        if not self._path:
            return None
        return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)

    def _predict_move(self, state) -> Action:
        """
        Internal Use

        Determines the action that brings the maintainer to the next position of its path. If the path is empty
        or a colliding entity occupies the next position, the maintainer waits (``c.NOOP``); the path position
        is only consumed when a move is returned.

        NOTE(review): annotated as ``Action``, but the returned value is compared against ``Action.name`` in
        ``get_move_action``, i.e. it is used as an action name.

        :param state: The current state of the environment.
        :return: The action name to perform.
        :raises StopIteration: If no entry of ``h.MOVEMAP`` matches the required position difference.
        """
        if not self._path:
            return c.NOOP
        next_pos = self._path[0]
        if any(x.var_can_collide for x in state.entities.pos_dict[next_pos]):
            return c.NOOP
        next_pos = self._path.pop(0)
        diff = np.subtract(next_pos, self.pos)
        # Retrieve action based on the pos dif (like in: What do I have to do to get there?)
        return next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))

    def render(self):
        """
        :return: The render entity that represents the maintainer at its current position.
        :rtype: RenderEntity
        """
        return RenderEntity(mi.MAINTAINER, self.pos)
|