Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/modules/doors/groups.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask
2023-11-23 12:58:12 +01:00
63 changed files with 1477 additions and 330 deletions

View File

@ -15,7 +15,15 @@ from ..doors import DoorUse
class Maintainer(Entity):
def __init__(self, objective: str, action: Action, *args, **kwargs):
def __init__(self, objective, action, *args, **kwargs):
"""
Represents the maintainer entity that aims to maintain machines.
:param objective: The maintainer's objective, e.g., "Machines".
:type objective: str
:param action: The default action to be performed by the maintainer.
:type action: Action
"""
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
@ -26,6 +34,16 @@ class Maintainer(Entity):
self._last_serviced = 'None'
def tick(self, state):
"""
If there is an objective at the current position, the maintainer performs its action on the objective.
If the objective has changed since the last servicing, the maintainer performs the action and updates
the last serviced objective. Otherwise, it calculates a move action and performs it.
:param state: The current game state.
:type state: GameState
:return: The result of the action performed by the maintainer.
:rtype: ActionResult
"""
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced:
result = self.action.do(self, state)
@ -40,9 +58,24 @@ class Maintainer(Entity):
return result
def set_state(self, action_result):
"""
Updates the maintainers own status with an action result.
"""
self._status = action_result
def get_move_action(self, state) -> Action:
"""
Retrieves the next move action for the agent.
If a path is not already determined, the agent calculates the shortest path to its objective, considering doors
and obstacles. If a closed door is found in the calculated path, the agent attempts to open it.
:param state: The current state of the environment.
:type state: GameState
:return: The chosen move action for the agent.
:rtype: Action
"""
if self._path is None or not len(self._path):
if not self._next:
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
@ -70,17 +103,27 @@ class Maintainer(Entity):
raise EnvironmentError
return action_obj
def calculate_route(self, entity, floortile_graph):
def calculate_route(self, entity, floortile_graph) -> list:
"""
:returns: path, include both the source and target position
:rtype: list
"""
route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
return route[1:]
def _closed_door_in_path(self, state):
"""
Internal Use
"""
if self._path:
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
else:
return None
def _predict_move(self, state):
def _predict_move(self, state) -> Action:
"""
Internal Use
"""
next_pos = self._path[0]
if any(x for x in state.entities.pos_dict[next_pos] if x.var_can_collide) > 0:
action = c.NOOP

View File

@ -9,12 +9,26 @@ from ..machines.actions import MachineAction
class Maintainers(Collection):
_entity = Maintainer
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return True
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, *args, **kwargs):
"""
A collection of maintainers
"""
super().__init__(*args, **kwargs)
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):

View File

@ -9,6 +9,9 @@ from . import constants as M
class MoveMaintainers(Rule):
def __init__(self):
"""
This rule is responsible for moving the maintainers at every step of the environment.
"""
super().__init__()
def tick_step(self, state) -> List[TickResult]:
@ -22,6 +25,9 @@ class MoveMaintainers(Rule):
class DoneAtMaintainerCollision(Rule):
def __init__(self):
"""
When active, this rule stops the environment after a maintainer reports a collision with another entity.
"""
super().__init__()
def on_check_done(self, state) -> List[DoneResult]: