From 51e2f71d154cda103d080823b69a27a24d72de60 Mon Sep 17 00:00:00 2001 From: Joel Friedrich Date: Fri, 22 Dec 2023 11:01:44 +0100 Subject: [PATCH] finished documentation of rules.py and base action --- marl_factory_grid/environment/actions.py | 24 ++-- marl_factory_grid/environment/rules.py | 143 ++++++++++++++++++----- marl_factory_grid/utils/config_parser.py | 4 +- 3 files changed, 129 insertions(+), 42 deletions(-) diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 231a74d..e802741 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -6,25 +6,30 @@ from marl_factory_grid.environment import rewards as r, constants as c from marl_factory_grid.utils.helpers import MOVEMAP from marl_factory_grid.utils.results import ActionResult - TYPE_COLLISION = 'collision' + class Action(abc.ABC): @property def name(self): return self._identifier @abc.abstractmethod - def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float, + def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float, valid_reward: float | None = None, fail_reward: float | None = None): """ - Todo + Abstract base class representing an action that can be performed in the environment. - :param identifier: - :param default_valid_reward: - :param default_fail_reward: - :param valid_reward: - :param fail_reward: + :param identifier: A unique identifier for the action. + :type identifier: str + :param default_valid_reward: Default reward for a valid action. + :type default_valid_reward: float + :param default_fail_reward: Default reward for a failed action. + :type default_fail_reward: float + :param valid_reward: Custom reward for a valid action (optional). + :type valid_reward: Union[float, optional] + :param fail_reward: Custom reward for a failed action (optional). + :type fail_reward: Union[float, optional] """ self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward @@ -46,6 +51,9 @@ class Action(abc.ABC): return f'Action[{self._identifier}]' def get_result(self, validity, entity, action_introduced_collision=False): + """ + Generate an ActionResult for the action based on its validity. + """ reward = self.valid_reward if validity else self.fail_reward return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity, action_introduced_collision=action_introduced_collision) diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index ac2faf7..82f7696 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -16,85 +16,128 @@ class Rule(abc.ABC): @property def name(self): """ - TODO + Get the name of the rule. - - :return: - """ + :return: The name of the rule. + :rtype: str + """ return self.__class__.__name__ def __init__(self): """ - TODO + Abstract base class representing a rule in the environment. + This class provides a framework for defining rules that govern the behavior of the environment. Rules can be + implemented by inheriting from this class and overriding specific methods. - :return: """ pass - def __repr__(self): + def __repr__(self) -> str: + """ + Return a string representation of the rule. + + :return: A string representation of the rule. + :rtype: str + """ return f'{self.name}' def on_init(self, state, lvl_map): """ - TODO + Initialize the rule when the environment is created. + This method is called during the initialization of the environment. It allows the rule to perform any setup or + initialization required. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :param lvl_map: The map of the level. + :type lvl_map: marl_factory_grid.environment.level.LevelMap + :return: List of TickResults generated during initialization. + :rtype: List[TickResult] """ return [] def on_reset_post_spawn(self, state) -> List[TickResult]: """ - TODO + Execute actions after entities are spawned during a reset. + This method is called after entities are spawned during a reset. It allows the rule to perform any actions + required at this stage. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of TickResults generated after entity spawning. + :rtype: List[TickResult] """ return [] def on_reset(self, state) -> List[TickResult]: """ - TODO + Execute actions during a reset. + This method is called during a reset. It allows the rule to perform any actions required at this stage. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of TickResults generated during a reset. + :rtype: List[TickResult] """ return [] def tick_pre_step(self, state) -> List[TickResult]: """ - TODO + Execute actions before the main step of the environment. + This method is called before the main step of the environment. It allows the rule to perform any actions + required before the main step. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of TickResults generated before the main step. + :rtype: List[TickResult] """ return [] def tick_step(self, state) -> List[TickResult]: """ - TODO + Execute actions during the main step of the environment. + This method is called during the main step of the environment. It allows the rule to perform any actions + required during the main step. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of TickResults generated during the main step. + :rtype: List[TickResult] """ return [] def tick_post_step(self, state) -> List[TickResult]: """ - TODO + Execute actions after the main step of the environment. + This method is called after the main step of the environment. It allows the rule to perform any actions + required after the main step. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of TickResults generated after the main step. + :rtype: List[TickResult] """ return [] def on_check_done(self, state) -> List[DoneResult]: """ - TODO + Check conditions for the termination of the environment. + This method is called to check conditions for the termination of the environment. It allows the rule to + specify conditions under which the environment should be considered done. - :return: + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of DoneResults indicating whether the environment is done. + :rtype: List[DoneResult] """ return [] @@ -160,15 +203,23 @@ class DoneAtMaxStepsReached(Rule): def __init__(self, max_steps: int = 500): """ - TODO + A rule that terminates the environment when a specified maximum number of steps is reached. - - :return: - """ + :param max_steps: The maximum number of steps before the environment is considered done. + :type max_steps: int + """ super().__init__() self.max_steps = max_steps def on_check_done(self, state): + """ + Check if the maximum number of steps is reached, and if so, mark the environment as done. + + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: List of DoneResults indicating whether the environment is done. + :rtype: List[DoneResult] + """ if self.max_steps <= state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name)] return [] @@ -178,14 +229,23 @@ class AssignGlobalPositions(Rule): def __init__(self): """ - TODO + A rule that assigns global positions to agents when the environment is reset. - - :return: + :return: None """ super().__init__() def on_reset(self, state, lvl_map): + """ + Assign global positions to agents when the environment is reset. + + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :param lvl_map: The map of the current level. + :type lvl_map: marl_factory_grid.levels.level.LevelMap + :return: An empty list, as no additional results are generated by this rule during the reset. + :rtype: List[TickResult] + """ from marl_factory_grid.environment.entity.util import GlobalPosition for agent in state[c.AGENT]: gp = GlobalPosition(agent, lvl_map.level_shape) @@ -197,10 +257,15 @@ class WatchCollisions(Rule): def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE): """ - TODO + A rule that monitors collisions between entities in the environment. - - :return: + :param reward: The reward assigned for each collision. + :type reward: float + :param done_at_collisions: If True, marks the environment as done when collisions occur. + :type done_at_collisions: bool + :param reward_at_done: The reward assigned when the environment is marked as done due to collisions. + :type reward_at_done: float + :return: None """ super().__init__() self.reward_at_done = reward_at_done @@ -209,6 +274,14 @@ class WatchCollisions(Rule): self.curr_done = False def tick_post_step(self, state) -> List[TickResult]: + """ + Monitors collisions between entities after each step in the environment. + + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: A list of TickResult objects representing collisions and their associated rewards. + :rtype: List[TickResult] + """ self.curr_done = False results = list() for agent in state[c.AGENT]: @@ -234,6 +307,14 @@ class WatchCollisions(Rule): return results def on_check_done(self, state) -> List[DoneResult]: + """ + Checks if the environment should be marked as done based on collision conditions. + + :param state: The current game state. + :type state: marl_factory_grid.utils.states.GameState + :return: A list of DoneResult objects representing the conditions for marking the environment as done. + :rtype: List[DoneResult] + """ if self.done_at_collisions: inter_entity_collision_detected = self.curr_done collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index 26cf007..926d302 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -23,7 +23,6 @@ class FactoryConfigParser(object): """ This class parses the factory env config file. - :param config_path: Path to where the 'config.yml' is. :param custom_modules_path: Additional search path for custom modules, levels, entities, etc.. """ @@ -45,7 +44,6 @@ class FactoryConfigParser(object): self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'}) return self._n_abbr_dict[n] - @property def agent_actions(self): return self._get_sub_list('Agents', "Actions") @@ -176,7 +174,7 @@ class FactoryConfigParser(object): ['Actions', 'Observations', 'Positions', 'Clones']} parsed_agents_conf[name] = dict( actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs - ) + ) clones = self.agents[name].get('Clones', 0) if clones: