diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 53403a7..a2281e3 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -7,6 +7,8 @@ from marl_factory_grid.utils.helpers import MOVEMAP from marl_factory_grid.utils.results import ActionResult +TYPE_COLLISION = 'collision' + class Action(abc.ABC): @property diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index d64ede9..b793f80 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -212,8 +212,7 @@ class Factory(gym.Env): # Combine Info dicts into a global one combined_info_dict = defaultdict(lambda: 0.0) for result in chain(tick_results, done_check_results): - if not result: - raise ValueError() + assert result, 'Something returned None...' if result.reward is not None: try: rewards[result.entity.name] += result.reward diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py index a40004b..0fb260e 100644 --- a/marl_factory_grid/utils/results.py +++ b/marl_factory_grid/utils/results.py @@ -2,9 +2,11 @@ from typing import Union from dataclasses import dataclass from marl_factory_grid.environment.entity.object import Object - +import marl_factory_grid.environment.constants as c TYPE_VALUE = 'value' TYPE_REWARD = 'reward' + + TYPES = [TYPE_VALUE, TYPE_REWARD] @@ -32,8 +34,9 @@ class Result: """ identifier: str validity: bool - reward: Union[float, None] = None - value: Union[float, None] = None + reward: float | None = None + value: float | None = None + collision: bool | None = None entity: Object = None def get_infos(self): @@ -68,8 +71,17 @@ class ActionResult(Result): super().__init__(*args, **kwargs) self.action_introduced_collision = action_introduced_collision - pass + def __repr__(self): + sr = super().__repr__() + return sr + f" | {c.COLLISION}" if self.action_introduced_collision is not None else "" + def get_infos(self): + base_infos = super().get_infos() + if self.action_introduced_collision: + i = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1) + return base_infos + [i] + else: + return base_infos @dataclass class DoneResult(Result):