mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Info Updated
This commit is contained in:
@ -7,6 +7,8 @@ from marl_factory_grid.utils.helpers import MOVEMAP
|
|||||||
from marl_factory_grid.utils.results import ActionResult
|
from marl_factory_grid.utils.results import ActionResult
|
||||||
|
|
||||||
|
|
||||||
|
TYPE_COLLISION = 'collision'
|
||||||
|
|
||||||
class Action(abc.ABC):
|
class Action(abc.ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -212,8 +212,7 @@ class Factory(gym.Env):
|
|||||||
# Combine Info dicts into a global one
|
# Combine Info dicts into a global one
|
||||||
combined_info_dict = defaultdict(lambda: 0.0)
|
combined_info_dict = defaultdict(lambda: 0.0)
|
||||||
for result in chain(tick_results, done_check_results):
|
for result in chain(tick_results, done_check_results):
|
||||||
if not result:
|
assert result, 'Something returned None...'
|
||||||
raise ValueError()
|
|
||||||
if result.reward is not None:
|
if result.reward is not None:
|
||||||
try:
|
try:
|
||||||
rewards[result.entity.name] += result.reward
|
rewards[result.entity.name] += result.reward
|
||||||
|
@ -2,9 +2,11 @@ from typing import Union
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.object import Object
|
from marl_factory_grid.environment.entity.object import Object
|
||||||
|
import marl_factory_grid.environment.constants as c
|
||||||
TYPE_VALUE = 'value'
|
TYPE_VALUE = 'value'
|
||||||
TYPE_REWARD = 'reward'
|
TYPE_REWARD = 'reward'
|
||||||
|
|
||||||
|
|
||||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||||
|
|
||||||
|
|
||||||
@ -32,8 +34,9 @@ class Result:
|
|||||||
"""
|
"""
|
||||||
identifier: str
|
identifier: str
|
||||||
validity: bool
|
validity: bool
|
||||||
reward: Union[float, None] = None
|
reward: float | None = None
|
||||||
value: Union[float, None] = None
|
value: float | None = None
|
||||||
|
collision: bool | None = None
|
||||||
entity: Object = None
|
entity: Object = None
|
||||||
|
|
||||||
def get_infos(self):
|
def get_infos(self):
|
||||||
@ -68,8 +71,17 @@ class ActionResult(Result):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.action_introduced_collision = action_introduced_collision
|
self.action_introduced_collision = action_introduced_collision
|
||||||
|
|
||||||
pass
|
def __repr__(self):
|
||||||
|
sr = super().__repr__()
|
||||||
|
return sr + f" | {c.COLLISION}" if self.action_introduced_collision is not None else ""
|
||||||
|
|
||||||
|
def get_infos(self):
|
||||||
|
base_infos = super().get_infos()
|
||||||
|
if self.action_introduced_collision:
|
||||||
|
i = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1)
|
||||||
|
return base_infos + [i]
|
||||||
|
else:
|
||||||
|
return base_infos
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DoneResult(Result):
|
class DoneResult(Result):
|
||||||
|
Reference in New Issue
Block a user