Info Updated

This commit is contained in:
Steffen Illium
2023-11-24 15:41:44 +01:00
parent 08db1dfc6f
commit 17613c3ba9
3 changed files with 19 additions and 6 deletions

View File

@ -7,6 +7,8 @@ from marl_factory_grid.utils.helpers import MOVEMAP
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
TYPE_COLLISION = 'collision'
class Action(abc.ABC): class Action(abc.ABC):
@property @property

View File

@ -212,8 +212,7 @@ class Factory(gym.Env):
# Combine Info dicts into a global one # Combine Info dicts into a global one
combined_info_dict = defaultdict(lambda: 0.0) combined_info_dict = defaultdict(lambda: 0.0)
for result in chain(tick_results, done_check_results): for result in chain(tick_results, done_check_results):
if not result: assert result, 'Something returned None...'
raise ValueError()
if result.reward is not None: if result.reward is not None:
try: try:
rewards[result.entity.name] += result.reward rewards[result.entity.name] += result.reward

View File

@ -2,9 +2,11 @@ from typing import Union
from dataclasses import dataclass from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
TYPE_VALUE = 'value' TYPE_VALUE = 'value'
TYPE_REWARD = 'reward' TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD] TYPES = [TYPE_VALUE, TYPE_REWARD]
@ -32,8 +34,9 @@ class Result:
""" """
identifier: str identifier: str
validity: bool validity: bool
reward: Union[float, None] = None reward: float | None = None
value: Union[float, None] = None value: float | None = None
collision: bool | None = None
entity: Object = None entity: Object = None
def get_infos(self): def get_infos(self):
@ -68,8 +71,17 @@ class ActionResult(Result):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.action_introduced_collision = action_introduced_collision self.action_introduced_collision = action_introduced_collision
pass def __repr__(self):
sr = super().__repr__()
return sr + f" | {c.COLLISION}" if self.action_introduced_collision is not None else ""
def get_infos(self):
base_infos = super().get_infos()
if self.action_introduced_collision:
i = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1)
return base_infos + [i]
else:
return base_infos
@dataclass @dataclass
class DoneResult(Result): class DoneResult(Result):