recorder fixed

This commit is contained in:
Steffen Illium
2022-08-18 16:15:17 +02:00
parent 6a24e7b518
commit 4f3924d3ab
19 changed files with 104 additions and 104 deletions
+9 -3
View File
@@ -71,6 +71,12 @@ class BaseFactory(gym.Env):
d['class_name'] = self.__class__.__name__
return d
@property
def summarize_header(self):
summary_dict = self._summarize_state(stateless_entities=True)
summary_dict.update(actions=self._actions.summarize())
return summary_dict
def __enter__(self):
return self if self.obs_prop.frames_to_stack == 0 else \
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
@@ -665,12 +671,12 @@ class BaseFactory(gym.Env):
else:
return []
def _summarize_state(self):
def _summarize_state(self, stateless_entities=False):
summary = {f'{REC_TAC}step': self._steps}
for entity_group in self._entities:
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
if entity_group.is_stateless == stateless_entities:
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
return summary
def print(self, string):
+6 -6
View File
@@ -86,7 +86,7 @@ class EnvObject(Object):
# TODO: Missing Documentation
class Entity(EnvObject):
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
"""Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""
@property
def can_collide(self):
@@ -113,7 +113,7 @@ class Entity(EnvObject):
self._tile = tile
tile.enter(self)
def summarize_state(self, **_) -> dict:
def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
tile=str(self.tile.name), can_collide=bool(self.can_collide))
@@ -338,8 +338,8 @@ class Door(Entity):
if not closed_on_init:
self._open()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@@ -402,7 +402,7 @@ class Agent(MoveableEntity):
# if attr.startswith('temp'):
self.step_result = None
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
def summarize_state(self):
state_dict = super().summarize_state()
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
return state_dict
+12 -15
View File
@@ -19,6 +19,11 @@ from environments.helpers import Constants as c
class ObjectCollection:
_accepted_objects = Object
_stateless_entities = False
@property
def is_stateless(self):
return self._stateless_entities
@property
def name(self):
@@ -116,8 +121,8 @@ class EnvObjectCollection(ObjectCollection):
self._lazy_eval_transforms = []
return self._array
def summarize_states(self, n_steps=None):
return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]
def summarize_states(self):
return [entity.summarize_state() for entity in self.values()]
def notify_change_to_free(self, env_object: EnvObject):
self._array_change_notifyer(env_object, value=c.FREE_CELL)
@@ -290,9 +295,6 @@ class GlobalPositions(EnvObjectCollection):
# noinspection PyTypeChecker
self.add_additional_items(global_positions)
def summarize_states(self, n_steps=None):
return {}
def idx_by_entity(self, entity):
try:
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
@@ -376,6 +378,7 @@ class Entities(ObjectCollection):
class Walls(EntityCollection):
_accepted_objects = Wall
_stateless_entities = True
def as_array(self):
if not np.any(self._array):
@@ -406,15 +409,10 @@ class Walls(EntityCollection):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
if n_steps == h.STEPS_START:
return super(Walls, self).summarize_states(n_steps=n_steps)
else:
return {}
class Floors(Walls):
_accepted_objects = Floor
_stateless_entities = True
def __init__(self, *args, is_blocking_light=False, **kwargs):
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
@@ -436,10 +434,6 @@ class Floors(Walls):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
# Do not summarize
return {}
class Agents(MovingEntityObjectCollection):
_accepted_objects = Agent
@@ -521,6 +515,9 @@ class Actions(ObjectCollection):
def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values()
def summarize(self):
return [dict(name=action.identifier) for action in self]
class Zones(ObjectCollection):