finished documentation of rules.py and base action

This commit is contained in:
Joel Friedrich
2023-12-22 11:01:44 +01:00
parent 944887aa2e
commit 51e2f71d15
3 changed files with 129 additions and 42 deletions

View File

@ -6,9 +6,9 @@ from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.utils.helpers import MOVEMAP
from marl_factory_grid.utils.results import ActionResult
TYPE_COLLISION = 'collision'
class Action(abc.ABC):
@property
def name(self):
@ -18,13 +18,18 @@ class Action(abc.ABC):
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
valid_reward: float | None = None, fail_reward: float | None = None):
"""
Todo
Abstract base class representing an action that can be performed in the environment.
:param identifier:
:param default_valid_reward:
:param default_fail_reward:
:param valid_reward:
:param fail_reward:
:param identifier: A unique identifier for the action.
:type identifier: str
:param default_valid_reward: Default reward for a valid action.
:type default_valid_reward: float
:param default_fail_reward: Default reward for a failed action.
:type default_fail_reward: float
:param valid_reward: Custom reward for a valid action (optional).
:type valid_reward: Union[float, optional]
:param fail_reward: Custom reward for a failed action (optional).
:type fail_reward: Union[float, optional]
"""
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
@ -46,6 +51,9 @@ class Action(abc.ABC):
return f'Action[{self._identifier}]'
def get_result(self, validity, entity, action_introduced_collision=False):
"""
Generate an ActionResult for the action based on its validity.
"""
reward = self.valid_reward if validity else self.fail_reward
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
action_introduced_collision=action_introduced_collision)

View File

@ -16,85 +16,128 @@ class Rule(abc.ABC):
@property
def name(self):
"""
TODO
Get the name of the rule.
:return:
:return: The name of the rule.
:rtype: str
"""
return self.__class__.__name__
def __init__(self):
"""
TODO
Abstract base class representing a rule in the environment.
This class provides a framework for defining rules that govern the behavior of the environment. Rules can be
implemented by inheriting from this class and overriding specific methods.
:return:
"""
pass
def __repr__(self):
def __repr__(self) -> str:
"""
Return a string representation of the rule.
:return: A string representation of the rule.
:rtype: str
"""
return f'{self.name}'
def on_init(self, state, lvl_map):
"""
TODO
Initialize the rule when the environment is created.
This method is called during the initialization of the environment. It allows the rule to perform any setup or
initialization required.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:param lvl_map: The map of the level.
:type lvl_map: marl_factory_grid.environment.level.LevelMap
:return: List of TickResults generated during initialization.
:rtype: List[TickResult]
"""
return []
def on_reset_post_spawn(self, state) -> List[TickResult]:
"""
TODO
Execute actions after entities are spawned during a reset.
This method is called after entities are spawned during a reset. It allows the rule to perform any actions
required at this stage.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated after entity spawning.
:rtype: List[TickResult]
"""
return []
def on_reset(self, state) -> List[TickResult]:
"""
TODO
Execute actions during a reset.
This method is called during a reset. It allows the rule to perform any actions required at this stage.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated during a reset.
:rtype: List[TickResult]
"""
return []
def tick_pre_step(self, state) -> List[TickResult]:
"""
TODO
Execute actions before the main step of the environment.
This method is called before the main step of the environment. It allows the rule to perform any actions
required before the main step.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated before the main step.
:rtype: List[TickResult]
"""
return []
def tick_step(self, state) -> List[TickResult]:
"""
TODO
Execute actions during the main step of the environment.
This method is called during the main step of the environment. It allows the rule to perform any actions
required during the main step.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated during the main step.
:rtype: List[TickResult]
"""
return []
def tick_post_step(self, state) -> List[TickResult]:
"""
TODO
Execute actions after the main step of the environment.
This method is called after the main step of the environment. It allows the rule to perform any actions
required after the main step.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of TickResults generated after the main step.
:rtype: List[TickResult]
"""
return []
def on_check_done(self, state) -> List[DoneResult]:
"""
TODO
Check conditions for the termination of the environment.
This method is called to check conditions for the termination of the environment. It allows the rule to
specify conditions under which the environment should be considered done.
:return:
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of DoneResults indicating whether the environment is done.
:rtype: List[DoneResult]
"""
return []
@ -160,15 +203,23 @@ class DoneAtMaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
"""
TODO
A rule that terminates the environment when a specified maximum number of steps is reached.
:return:
:param max_steps: The maximum number of steps before the environment is considered done.
:type max_steps: int
"""
super().__init__()
self.max_steps = max_steps
def on_check_done(self, state):
"""
Check if the maximum number of steps is reached, and if so, mark the environment as done.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: List of DoneResults indicating whether the environment is done.
:rtype: List[DoneResult]
"""
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)]
return []
@ -178,14 +229,23 @@ class AssignGlobalPositions(Rule):
def __init__(self):
"""
TODO
A rule that assigns global positions to agents when the environment is reset.
:return:
:return: None
"""
super().__init__()
def on_reset(self, state, lvl_map):
"""
Assign global positions to agents when the environment is reset.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:param lvl_map: The map of the current level.
:type lvl_map: marl_factory_grid.levels.level.LevelMap
:return: An empty list, as no additional results are generated by this rule during the reset.
:rtype: List[TickResult]
"""
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(agent, lvl_map.level_shape)
@ -197,10 +257,15 @@ class WatchCollisions(Rule):
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
"""
TODO
A rule that monitors collisions between entities in the environment.
:return:
:param reward: The reward assigned for each collision.
:type reward: float
:param done_at_collisions: If True, marks the environment as done when collisions occur.
:type done_at_collisions: bool
:param reward_at_done: The reward assigned when the environment is marked as done due to collisions.
:type reward_at_done: float
:return: None
"""
super().__init__()
self.reward_at_done = reward_at_done
@ -209,6 +274,14 @@ class WatchCollisions(Rule):
self.curr_done = False
def tick_post_step(self, state) -> List[TickResult]:
"""
Monitors collisions between entities after each step in the environment.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: A list of TickResult objects representing collisions and their associated rewards.
:rtype: List[TickResult]
"""
self.curr_done = False
results = list()
for agent in state[c.AGENT]:
@ -234,6 +307,14 @@ class WatchCollisions(Rule):
return results
def on_check_done(self, state) -> List[DoneResult]:
"""
Checks if the environment should be marked as done based on collision conditions.
:param state: The current game state.
:type state: marl_factory_grid.utils.states.GameState
:return: A list of DoneResult objects representing the conditions for marking the environment as done.
:rtype: List[DoneResult]
"""
if self.done_at_collisions:
inter_entity_collision_detected = self.curr_done
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision

View File

@ -23,7 +23,6 @@ class FactoryConfigParser(object):
"""
This class parses the factory env config file.
:param config_path: Path to where the 'config.yml' is.
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc..
"""
@ -45,7 +44,6 @@ class FactoryConfigParser(object):
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
return self._n_abbr_dict[n]
@property
def agent_actions(self):
return self._get_sub_list('Agents', "Actions")