finished documentation of rules.py and base action

This commit is contained in:
Joel Friedrich
2023-12-22 11:01:44 +01:00
parent 944887aa2e
commit 51e2f71d15
3 changed files with 129 additions and 42 deletions

View File

@ -6,9 +6,9 @@ from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.utils.helpers import MOVEMAP from marl_factory_grid.utils.helpers import MOVEMAP
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
TYPE_COLLISION = 'collision' TYPE_COLLISION = 'collision'
class Action(abc.ABC): class Action(abc.ABC):
@property @property
def name(self): def name(self):
@ -18,13 +18,18 @@ class Action(abc.ABC):
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float, def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
valid_reward: float | None = None, fail_reward: float | None = None): valid_reward: float | None = None, fail_reward: float | None = None):
""" """
Todo Abstract base class representing an action that can be performed in the environment.
:param identifier: :param identifier: A unique identifier for the action.
:param default_valid_reward: :type identifier: str
:param default_fail_reward: :param default_valid_reward: Default reward for a valid action.
:param valid_reward: :type default_valid_reward: float
:param fail_reward: :param default_fail_reward: Default reward for a failed action.
:type default_fail_reward: float
:param valid_reward: Custom reward for a valid action (optional).
:type valid_reward: Union[float, optional]
:param fail_reward: Custom reward for a failed action (optional).
:type fail_reward: Union[float, optional]
""" """
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
@ -46,6 +51,9 @@ class Action(abc.ABC):
return f'Action[{self._identifier}]' return f'Action[{self._identifier}]'
def get_result(self, validity, entity, action_introduced_collision=False): def get_result(self, validity, entity, action_introduced_collision=False):
"""
Generate an ActionResult for the action based on its validity.
"""
reward = self.valid_reward if validity else self.fail_reward reward = self.valid_reward if validity else self.fail_reward
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity, return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
action_introduced_collision=action_introduced_collision) action_introduced_collision=action_introduced_collision)

View File

@ -16,85 +16,128 @@ class Rule(abc.ABC):
@property @property
def name(self): def name(self):
""" """
TODO Get the name of the rule.
:return: The name of the rule.
:return: :rtype: str
""" """
return self.__class__.__name__ return self.__class__.__name__
def __init__(self): def __init__(self):
""" """
TODO Abstract base class representing a rule in the environment.
This class provides a framework for defining rules that govern the behavior of the environment. Rules can be
implemented by inheriting from this class and overriding specific methods.
:return:
""" """
pass pass
def __repr__(self): def __repr__(self) -> str:
"""
Return a string representation of the rule.
:return: A string representation of the rule.
:rtype: str
"""
return f'{self.name}' return f'{self.name}'
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
""" """
TODO Initialize the rule when the environment is created.
This method is called during the initialization of the environment. It allows the rule to perform any setup or
initialization required.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:param lvl_map: The map of the level.
:type lvl_map: marl_factory_grid.environment.level.LevelMap
:return: List of TickResults generated during initialization.
:rtype: List[TickResult]
""" """
return [] return []
def on_reset_post_spawn(self, state) -> List[TickResult]: def on_reset_post_spawn(self, state) -> List[TickResult]:
""" """
TODO Execute actions after entities are spawned during a reset.
This method is called after entities are spawned during a reset. It allows the rule to perform any actions
required at this stage.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated after entity spawning.
:rtype: List[TickResult]
""" """
return [] return []
def on_reset(self, state) -> List[TickResult]: def on_reset(self, state) -> List[TickResult]:
""" """
TODO Execute actions during a reset.
This method is called during a reset. It allows the rule to perform any actions required at this stage.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated during a reset.
:rtype: List[TickResult]
""" """
return [] return []
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:
""" """
TODO Execute actions before the main step of the environment.
This method is called before the main step of the environment. It allows the rule to perform any actions
required before the main step.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated before the main step.
:rtype: List[TickResult]
""" """
return [] return []
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
""" """
TODO Execute actions during the main step of the environment.
This method is called during the main step of the environment. It allows the rule to perform any actions
required during the main step.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated during the main step.
:rtype: List[TickResult]
""" """
return [] return []
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
""" """
TODO Execute actions after the main step of the environment.
This method is called after the main step of the environment. It allows the rule to perform any actions
required after the main step.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated after the main step.
:rtype: List[TickResult]
""" """
return [] return []
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
""" """
TODO Check conditions for the termination of the environment.
This method is called to check conditions for the termination of the environment. It allows the rule to
specify conditions under which the environment should be considered done.
:return: :param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of DoneResults indicating whether the environment is done.
:rtype: List[DoneResult]
""" """
return [] return []
@ -160,15 +203,23 @@ class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500): def __init__(self, max_steps: int = 500):
""" """
TODO A rule that terminates the environment when a specified maximum number of steps is reached.
:param max_steps: The maximum number of steps before the environment is considered done.
:return: :type max_steps: int
""" """
super().__init__() super().__init__()
self.max_steps = max_steps self.max_steps = max_steps
def on_check_done(self, state): def on_check_done(self, state):
"""
Check if the maximum number of steps is reached, and if so, mark the environment as done.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of DoneResults indicating whether the environment is done.
:rtype: List[DoneResult]
"""
if self.max_steps <= state.curr_step: if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)] return [DoneResult(validity=c.VALID, identifier=self.name)]
return [] return []
@ -178,14 +229,23 @@ class AssignGlobalPositions(Rule):
def __init__(self): def __init__(self):
""" """
TODO A rule that assigns global positions to agents when the environment is reset.
:return: None
:return:
""" """
super().__init__() super().__init__()
def on_reset(self, state, lvl_map): def on_reset(self, state, lvl_map):
"""
Assign global positions to agents when the environment is reset.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:param lvl_map: The map of the current level.
:type lvl_map: marl_factory_grid.levels.level.LevelMap
:return: An empty list, as no additional results are generated by this rule during the reset.
:rtype: List[TickResult]
"""
from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
gp = GlobalPosition(agent, lvl_map.level_shape) gp = GlobalPosition(agent, lvl_map.level_shape)
@ -197,10 +257,15 @@ class WatchCollisions(Rule):
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE): def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
""" """
TODO A rule that monitors collisions between entities in the environment.
:param reward: The reward assigned for each collision.
:return: :type reward: float
:param done_at_collisions: If True, marks the environment as done when collisions occur.
:type done_at_collisions: bool
:param reward_at_done: The reward assigned when the environment is marked as done due to collisions.
:type reward_at_done: float
:return: None
""" """
super().__init__() super().__init__()
self.reward_at_done = reward_at_done self.reward_at_done = reward_at_done
@ -209,6 +274,14 @@ class WatchCollisions(Rule):
self.curr_done = False self.curr_done = False
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
"""
Monitors collisions between entities after each step in the environment.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: A list of TickResult objects representing collisions and their associated rewards.
:rtype: List[TickResult]
"""
self.curr_done = False self.curr_done = False
results = list() results = list()
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
@ -234,6 +307,14 @@ class WatchCollisions(Rule):
return results return results
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
"""
Checks if the environment should be marked as done based on collision conditions.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: A list of DoneResult objects representing the conditions for marking the environment as done.
:rtype: List[DoneResult]
"""
if self.done_at_collisions: if self.done_at_collisions:
inter_entity_collision_detected = self.curr_done inter_entity_collision_detected = self.curr_done
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision

View File

@ -23,7 +23,6 @@ class FactoryConfigParser(object):
""" """
This class parses the factory env config file. This class parses the factory env config file.
:param config_path: Path to where the 'config.yml' is. :param config_path: Path to where the 'config.yml' is.
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.. :param custom_modules_path: Additional search path for custom modules, levels, entities, etc..
""" """
@ -45,7 +44,6 @@ class FactoryConfigParser(object):
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'}) self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
return self._n_abbr_dict[n] return self._n_abbr_dict[n]
@property @property
def agent_actions(self): def agent_actions(self):
return self._get_sub_list('Agents', "Actions") return self._get_sub_list('Agents', "Actions")