recorder fixed
This commit is contained in:
@@ -71,6 +71,12 @@ class BaseFactory(gym.Env):
|
||||
d['class_name'] = self.__class__.__name__
|
||||
return d
|
||||
|
||||
@property
|
||||
def summarize_header(self):
|
||||
summary_dict = self._summarize_state(stateless_entities=True)
|
||||
summary_dict.update(actions=self._actions.summarize())
|
||||
return summary_dict
|
||||
|
||||
def __enter__(self):
|
||||
return self if self.obs_prop.frames_to_stack == 0 else \
|
||||
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
||||
@@ -665,12 +671,12 @@ class BaseFactory(gym.Env):
|
||||
else:
|
||||
return []
|
||||
|
||||
def _summarize_state(self):
|
||||
def _summarize_state(self, stateless_entities=False):
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
for entity_group in self._entities:
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
|
||||
|
||||
if entity_group.is_stateless == stateless_entities:
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
|
||||
@@ -86,7 +86,7 @@ class EnvObject(Object):
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class Entity(EnvObject):
|
||||
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
||||
"""Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
@@ -113,7 +113,7 @@ class Entity(EnvObject):
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self, **_) -> dict:
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
|
||||
@@ -338,8 +338,8 @@ class Door(Entity):
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||
return state_dict
|
||||
|
||||
@@ -402,7 +402,7 @@ class Agent(MoveableEntity):
|
||||
# if attr.startswith('temp'):
|
||||
self.step_result = None
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
||||
return state_dict
|
||||
|
||||
@@ -19,6 +19,11 @@ from environments.helpers import Constants as c
|
||||
|
||||
class ObjectCollection:
|
||||
_accepted_objects = Object
|
||||
_stateless_entities = False
|
||||
|
||||
@property
|
||||
def is_stateless(self):
|
||||
return self._stateless_entities
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@@ -116,8 +121,8 @@ class EnvObjectCollection(ObjectCollection):
|
||||
self._lazy_eval_transforms = []
|
||||
return self._array
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]
|
||||
def summarize_states(self):
|
||||
return [entity.summarize_state() for entity in self.values()]
|
||||
|
||||
def notify_change_to_free(self, env_object: EnvObject):
|
||||
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||
@@ -290,9 +295,6 @@ class GlobalPositions(EnvObjectCollection):
|
||||
# noinspection PyTypeChecker
|
||||
self.add_additional_items(global_positions)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||
@@ -376,6 +378,7 @@ class Entities(ObjectCollection):
|
||||
|
||||
class Walls(EntityCollection):
|
||||
_accepted_objects = Wall
|
||||
_stateless_entities = True
|
||||
|
||||
def as_array(self):
|
||||
if not np.any(self._array):
|
||||
@@ -406,15 +409,10 @@ class Walls(EntityCollection):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
if n_steps == h.STEPS_START:
|
||||
return super(Walls, self).summarize_states(n_steps=n_steps)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class Floors(Walls):
|
||||
_accepted_objects = Floor
|
||||
_stateless_entities = True
|
||||
|
||||
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||
@@ -436,10 +434,6 @@ class Floors(Walls):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# Do not summarize
|
||||
return {}
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectCollection):
|
||||
_accepted_objects = Agent
|
||||
@@ -521,6 +515,9 @@ class Actions(ObjectCollection):
|
||||
def is_moving_action(self, action: Union[int]):
|
||||
return action in self.movement_actions.values()
|
||||
|
||||
def summarize(self):
|
||||
return [dict(name=action.identifier) for action in self]
|
||||
|
||||
|
||||
class Zones(ObjectCollection):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user