mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
finished documentation of rules.py and base action
This commit is contained in:
@ -6,25 +6,30 @@ from marl_factory_grid.environment import rewards as r, constants as c
|
|||||||
from marl_factory_grid.utils.helpers import MOVEMAP
|
from marl_factory_grid.utils.helpers import MOVEMAP
|
||||||
from marl_factory_grid.utils.results import ActionResult
|
from marl_factory_grid.utils.results import ActionResult
|
||||||
|
|
||||||
|
|
||||||
TYPE_COLLISION = 'collision'
|
TYPE_COLLISION = 'collision'
|
||||||
|
|
||||||
|
|
||||||
class Action(abc.ABC):
|
class Action(abc.ABC):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._identifier
|
return self._identifier
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
||||||
valid_reward: float | None = None, fail_reward: float | None = None):
|
valid_reward: float | None = None, fail_reward: float | None = None):
|
||||||
"""
|
"""
|
||||||
Todo
|
Abstract base class representing an action that can be performed in the environment.
|
||||||
|
|
||||||
:param identifier:
|
:param identifier: A unique identifier for the action.
|
||||||
:param default_valid_reward:
|
:type identifier: str
|
||||||
:param default_fail_reward:
|
:param default_valid_reward: Default reward for a valid action.
|
||||||
:param valid_reward:
|
:type default_valid_reward: float
|
||||||
:param fail_reward:
|
:param default_fail_reward: Default reward for a failed action.
|
||||||
|
:type default_fail_reward: float
|
||||||
|
:param valid_reward: Custom reward for a valid action (optional).
|
||||||
|
:type valid_reward: float | None
|
||||||
|
:param fail_reward: Custom reward for a failed action (optional).
|
||||||
|
:type fail_reward: float | None
|
||||||
"""
|
"""
|
||||||
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
|
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
|
||||||
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
|
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
|
||||||
@ -46,6 +51,9 @@ class Action(abc.ABC):
|
|||||||
return f'Action[{self._identifier}]'
|
return f'Action[{self._identifier}]'
|
||||||
|
|
||||||
def get_result(self, validity, entity, action_introduced_collision=False):
|
def get_result(self, validity, entity, action_introduced_collision=False):
|
||||||
|
"""
|
||||||
|
Generate an ActionResult for the action based on its validity.
|
||||||
|
"""
|
||||||
reward = self.valid_reward if validity else self.fail_reward
|
reward = self.valid_reward if validity else self.fail_reward
|
||||||
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
|
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
|
||||||
action_introduced_collision=action_introduced_collision)
|
action_introduced_collision=action_introduced_collision)
|
||||||
|
@ -16,85 +16,128 @@ class Rule(abc.ABC):
|
|||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
Get the name of the rule.
|
||||||
|
|
||||||
|
:return: The name of the rule.
|
||||||
:return:
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
Abstract base class representing a rule in the environment.
|
||||||
|
|
||||||
|
This class provides a framework for defining rules that govern the behavior of the environment. Rules can be
|
||||||
|
implemented by inheriting from this class and overriding specific methods.
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
|
"""
|
||||||
|
Return a string representation of the rule.
|
||||||
|
|
||||||
|
:return: A string representation of the rule.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
return f'{self.name}'
|
return f'{self.name}'
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
"""
|
"""
|
||||||
TODO
|
Initialize the rule when the environment is created.
|
||||||
|
|
||||||
|
This method is called during the initialization of the environment. It allows the rule to perform any setup or
|
||||||
|
initialization required.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:param lvl_map: The map of the level.
|
||||||
|
:type lvl_map: marl_factory_grid.environment.level.LevelMap
|
||||||
|
:return: List of TickResults generated during initialization.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def on_reset_post_spawn(self, state) -> List[TickResult]:
|
def on_reset_post_spawn(self, state) -> List[TickResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Execute actions after entities are spawned during a reset.
|
||||||
|
|
||||||
|
This method is called after entities are spawned during a reset. It allows the rule to perform any actions
|
||||||
|
required at this stage.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of TickResults generated after entity spawning.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def on_reset(self, state) -> List[TickResult]:
|
def on_reset(self, state) -> List[TickResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Execute actions during a reset.
|
||||||
|
|
||||||
|
This method is called during a reset. It allows the rule to perform any actions required at this stage.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of TickResults generated during a reset.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Execute actions before the main step of the environment.
|
||||||
|
|
||||||
|
This method is called before the main step of the environment. It allows the rule to perform any actions
|
||||||
|
required before the main step.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of TickResults generated before the main step.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Execute actions during the main step of the environment.
|
||||||
|
|
||||||
|
This method is called during the main step of the environment. It allows the rule to perform any actions
|
||||||
|
required during the main step.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of TickResults generated during the main step.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Execute actions after the main step of the environment.
|
||||||
|
|
||||||
|
This method is called after the main step of the environment. It allows the rule to perform any actions
|
||||||
|
required after the main step.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of TickResults generated after the main step.
|
||||||
|
:rtype: List[TickResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Check conditions for the termination of the environment.
|
||||||
|
|
||||||
|
This method is called to check conditions for the termination of the environment. It allows the rule to
|
||||||
|
specify conditions under which the environment should be considered done.
|
||||||
|
|
||||||
:return:
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of DoneResults indicating whether the environment is done.
|
||||||
|
:rtype: List[DoneResult]
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -160,15 +203,23 @@ class DoneAtMaxStepsReached(Rule):
|
|||||||
|
|
||||||
def __init__(self, max_steps: int = 500):
|
def __init__(self, max_steps: int = 500):
|
||||||
"""
|
"""
|
||||||
TODO
|
A rule that terminates the environment when a specified maximum number of steps is reached.
|
||||||
|
|
||||||
|
:param max_steps: The maximum number of steps before the environment is considered done.
|
||||||
:return:
|
:type max_steps: int
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
|
|
||||||
def on_check_done(self, state):
|
def on_check_done(self, state):
|
||||||
|
"""
|
||||||
|
Check if the maximum number of steps is reached, and if so, mark the environment as done.
|
||||||
|
|
||||||
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: List of DoneResults indicating whether the environment is done.
|
||||||
|
:rtype: List[DoneResult]
|
||||||
|
"""
|
||||||
if self.max_steps <= state.curr_step:
|
if self.max_steps <= state.curr_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||||
return []
|
return []
|
||||||
@ -178,14 +229,23 @@ class AssignGlobalPositions(Rule):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
A rule that assigns global positions to agents when the environment is reset.
|
||||||
|
|
||||||
|
:return: None
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_reset(self, state, lvl_map):
|
def on_reset(self, state, lvl_map):
|
||||||
|
"""
|
||||||
|
Assign global positions to agents when the environment is reset.
|
||||||
|
|
||||||
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:param lvl_map: The map of the current level.
|
||||||
|
:type lvl_map: marl_factory_grid.levels.level.LevelMap
|
||||||
|
:return: An empty list, as no additional results are generated by this rule during the reset.
|
||||||
|
:rtype: List[TickResult]
|
||||||
|
"""
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||||
@ -197,10 +257,15 @@ class WatchCollisions(Rule):
|
|||||||
|
|
||||||
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
|
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
|
||||||
"""
|
"""
|
||||||
TODO
|
A rule that monitors collisions between entities in the environment.
|
||||||
|
|
||||||
|
:param reward: The reward assigned for each collision.
|
||||||
:return:
|
:type reward: float
|
||||||
|
:param done_at_collisions: If True, marks the environment as done when collisions occur.
|
||||||
|
:type done_at_collisions: bool
|
||||||
|
:param reward_at_done: The reward assigned when the environment is marked as done due to collisions.
|
||||||
|
:type reward_at_done: float
|
||||||
|
:return: None
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.reward_at_done = reward_at_done
|
self.reward_at_done = reward_at_done
|
||||||
@ -209,6 +274,14 @@ class WatchCollisions(Rule):
|
|||||||
self.curr_done = False
|
self.curr_done = False
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
|
"""
|
||||||
|
Monitors collisions between entities after each step in the environment.
|
||||||
|
|
||||||
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: A list of TickResult objects representing collisions and their associated rewards.
|
||||||
|
:rtype: List[TickResult]
|
||||||
|
"""
|
||||||
self.curr_done = False
|
self.curr_done = False
|
||||||
results = list()
|
results = list()
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
@ -234,6 +307,14 @@ class WatchCollisions(Rule):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
|
"""
|
||||||
|
Checks if the environment should be marked as done based on collision conditions.
|
||||||
|
|
||||||
|
:param state: The current game state.
|
||||||
|
:type state: marl_factory_grid.utils.states.GameState
|
||||||
|
:return: A list of DoneResult objects representing the conditions for marking the environment as done.
|
||||||
|
:rtype: List[DoneResult]
|
||||||
|
"""
|
||||||
if self.done_at_collisions:
|
if self.done_at_collisions:
|
||||||
inter_entity_collision_detected = self.curr_done
|
inter_entity_collision_detected = self.curr_done
|
||||||
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
|
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
|
||||||
|
@ -23,7 +23,6 @@ class FactoryConfigParser(object):
|
|||||||
"""
|
"""
|
||||||
This class parses the factory env config file.
|
This class parses the factory env config file.
|
||||||
|
|
||||||
|
|
||||||
:param config_path: Path to where the 'config.yml' is.
|
:param config_path: Path to where the 'config.yml' is.
|
||||||
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.
|
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.
|
||||||
"""
|
"""
|
||||||
@ -45,7 +44,6 @@ class FactoryConfigParser(object):
|
|||||||
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
|
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
|
||||||
return self._n_abbr_dict[n]
|
return self._n_abbr_dict[n]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agent_actions(self):
|
def agent_actions(self):
|
||||||
return self._get_sub_list('Agents', "Actions")
|
return self._get_sub_list('Agents', "Actions")
|
||||||
@ -176,7 +174,7 @@ class FactoryConfigParser(object):
|
|||||||
['Actions', 'Observations', 'Positions', 'Clones']}
|
['Actions', 'Observations', 'Positions', 'Clones']}
|
||||||
parsed_agents_conf[name] = dict(
|
parsed_agents_conf[name] = dict(
|
||||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
clones = self.agents[name].get('Clones', 0)
|
clones = self.agents[name].get('Clones', 0)
|
||||||
if clones:
|
if clones:
|
||||||
|
Reference in New Issue
Block a user