mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
Info Updated
This commit is contained in:
@ -2,9 +2,11 @@ from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
import marl_factory_grid.environment.constants as c
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
|
||||
|
||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
|
||||
@ -32,8 +34,9 @@ class Result:
|
||||
"""
|
||||
identifier: str
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
reward: float | None = None
|
||||
value: float | None = None
|
||||
collision: bool | None = None
|
||||
entity: Object = None
|
||||
|
||||
def get_infos(self):
|
||||
@ -68,8 +71,17 @@ class ActionResult(Result):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.action_introduced_collision = action_introduced_collision
|
||||
|
||||
pass
|
||||
def __repr__(self):
|
||||
sr = super().__repr__()
|
||||
return sr + f" | {c.COLLISION}" if self.action_introduced_collision is not None else ""
|
||||
|
||||
def get_infos(self):
|
||||
base_infos = super().get_infos()
|
||||
if self.action_introduced_collision:
|
||||
i = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1)
|
||||
return base_infos + [i]
|
||||
else:
|
||||
return base_infos
|
||||
|
||||
@dataclass
|
||||
class DoneResult(Result):
|
||||
|
Reference in New Issue
Block a user