finished documentation of rules.py and base action

This commit is contained in:
Joel Friedrich
2023-12-22 11:01:44 +01:00
parent 944887aa2e
commit 51e2f71d15
3 changed files with 129 additions and 42 deletions

View File

@ -6,25 +6,30 @@ from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.utils.helpers import MOVEMAP
from marl_factory_grid.utils.results import ActionResult
TYPE_COLLISION = 'collision'
class Action(abc.ABC):
@property
def name(self):
    """Return the identifier stored on this action as ``_identifier``."""
    identifier = self._identifier
    return identifier
@abc.abstractmethod
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
valid_reward: float | None = None, fail_reward: float | None = None):
"""
Todo
Abstract base class representing an action that can be performed in the environment.
:param identifier:
:param default_valid_reward:
:param default_fail_reward:
:param valid_reward:
:param fail_reward:
:param identifier: A unique identifier for the action.
:type identifier: str
:param default_valid_reward: Default reward for a valid action.
:type default_valid_reward: float
:param default_fail_reward: Default reward for a failed action.
:type default_fail_reward: float
:param valid_reward: Custom reward for a valid action (optional).
:type valid_reward: float | None
:param fail_reward: Custom reward for a failed action (optional).
:type fail_reward: float | None
"""
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
@ -46,6 +51,9 @@ class Action(abc.ABC):
return f'Action[{self._identifier}]'
def get_result(self, validity, entity, action_introduced_collision=False):
    """
    Build the ActionResult describing the outcome of this action.

    The reward is ``self.valid_reward`` when ``validity`` is truthy and
    ``self.fail_reward`` otherwise. The result is labelled with the name of
    the concrete action class.

    :param validity: Whether the action was carried out successfully.
    :param entity: The entity handed through to the ActionResult.
    :param action_introduced_collision: Flag forwarded unchanged to the
        ActionResult; defaults to False.
    :return: The ActionResult constructed from the arguments above.
    """
    if validity:
        reward = self.valid_reward
    else:
        reward = self.fail_reward
    action_label = self.__class__.__name__
    return ActionResult(
        action_label,
        validity,
        reward=reward,
        entity=entity,
        action_introduced_collision=action_introduced_collision,
    )