From 4f3924d3ababaee223a9ed7750a33ae86f016294 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Thu, 18 Aug 2022 16:15:17 +0200 Subject: [PATCH] recorder fixed --- algorithms/TSP_dirt_agent.py | 2 +- .../additional/btry/btry_collections.py | 17 +++------- .../factory/additional/btry/btry_objects.py | 9 +----- .../additional/btry/factory_battery.py | 6 ++-- .../additional/dest/dest_collections.py | 6 ---- .../factory/additional/dest/dest_enitites.py | 7 ++-- .../additional/dirt/dirt_collections.py | 12 +++---- .../factory/additional/dirt/dirt_entity.py | 8 ++--- .../factory/additional/dirt/dirt_util.py | 2 +- .../factory/additional/dirt/factory_dirt.py | 12 +++---- .../factory/additional/item/factory_item.py | 4 +-- .../additional/item/item_collections.py | 12 ++++--- .../factory/additional/item/item_entities.py | 6 +--- environments/factory/base/base_factory.py | 12 +++++-- environments/factory/base/objects.py | 12 +++---- environments/factory/base/registers.py | 27 +++++++--------- .../factory_dirt_stationary_machines.py | 4 +-- environments/logging/recorder.py | 32 +++++++++++++++---- quickstart/combine_and_monitor_rerun.py | 18 ++++++----- 19 files changed, 104 insertions(+), 104 deletions(-) diff --git a/algorithms/TSP_dirt_agent.py b/algorithms/TSP_dirt_agent.py index c11c2b5..dcdd6dd 100644 --- a/algorithms/TSP_dirt_agent.py +++ b/algorithms/TSP_dirt_agent.py @@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions class Constants(BaseConstants): - DIRT = 'Dirt' + DIRT = 'DirtPile' class Actions(BaseActions): diff --git a/environments/factory/additional/btry/btry_collections.py b/environments/factory/additional/btry/btry_collections.py index 350324e..eb7c2cc 100644 --- a/environments/factory/additional/btry/btry_collections.py +++ b/environments/factory/additional/btry/btry_collections.py @@ -2,24 +2,19 @@ from environments.factory.additional.btry.btry_objects import Battery, ChargePod from environments.factory.base.registers import EnvObjectCollection, EntityCollection -class BatteriesRegister(EnvObjectCollection): +class Batteries(EnvObjectCollection): _accepted_objects = Battery def __init__(self, *args, **kwargs): - super(BatteriesRegister, self).__init__(*args, individual_slices=True, - is_blocking_light=False, can_be_shadowed=False, **kwargs) + super(Batteries, self).__init__(*args, individual_slices=True, + is_blocking_light=False, can_be_shadowed=False, **kwargs) self.is_observable = True def spawn_batteries(self, agents, initial_charge_level): batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] self.add_additional_items(batteries) - def summarize_states(self, n_steps=None): - # as dict with additional nesting - # return dict(items=super(Inventories, cls).summarize_states()) - return super(BatteriesRegister, self).summarize_states(n_steps=n_steps) - # Todo Move this to Mixin! def by_entity(self, entity): try: @@ -40,11 +35,7 @@ class BatteriesRegister(EnvObjectCollection): class ChargePods(EntityCollection): _accepted_objects = ChargePod + _stateless_entities = True def __repr__(self): super(ChargePods, self).__repr__() - - def summarize_states(self, n_steps=None): - # as dict with additional nesting - # return dict(items=super(Inventories, cls).summarize_states()) - return super(ChargePods, self).summarize_states(n_steps=n_steps) diff --git a/environments/factory/additional/btry/btry_objects.py b/environments/factory/additional/btry/btry_objects.py index 36cae74..6516a36 100644 --- a/environments/factory/additional/btry/btry_objects.py +++ b/environments/factory/additional/btry/btry_objects.py @@ -35,7 +35,7 @@ class Battery(BoundingMixin, EnvObject): def summarize_state(self, **_): attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} - attr_dict.update(dict(name=self.name)) + attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name)) return attr_dict @@ -58,10 +58,3 @@ class ChargePod(Entity): return c.NOT_VALID valid = battery.do_charge_action(self.charge_rate) return valid - - def summarize_state(self, n_steps=None) -> dict: - if n_steps == h.STEPS_START: - summary = super().summarize_state(n_steps=n_steps) - return summary - else: - {} diff --git a/environments/factory/additional/btry/factory_battery.py b/environments/factory/additional/btry/factory_battery.py index 3271c6e..6cead04 100644 --- a/environments/factory/additional/btry/factory_battery.py +++ b/environments/factory/additional/btry/factory_battery.py @@ -2,7 +2,7 @@ from typing import Dict, List import numpy as np -from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods +from environments.factory.additional.btry.btry_collections import Batteries, ChargePods from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties from environments.factory.base.base_factory import BaseFactory from environments.factory.base.objects import Agent, Action @@ -45,8 +45,8 @@ class BatteryFactory(BaseFactory): multi_charge=self.btry_prop.multi_charge) ) - batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), - ) + batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), + ) batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge) super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods}) return super_entities diff --git a/environments/factory/additional/dest/dest_collections.py b/environments/factory/additional/dest/dest_collections.py index 6919898..9bf6ce2 100644 --- a/environments/factory/additional/dest/dest_collections.py +++ b/environments/factory/additional/dest/dest_collections.py @@ -25,9 +25,6 @@ class Destinations(EntityCollection): def __repr__(self): return super(Destinations, self).__repr__() - def summarize_states(self, n_steps=None): - return {} - class ReachedDestinations(Destinations): _accepted_objects = Destination @@ -37,8 +34,5 @@ class ReachedDestinations(Destinations): self.can_be_shadowed = False self.is_blocking_light = False - def summarize_states(self, n_steps=None): - return {} - def __repr__(self): return super(ReachedDestinations, self).__repr__() diff --git a/environments/factory/additional/dest/dest_enitites.py b/environments/factory/additional/dest/dest_enitites.py index 5050a6d..1c45dec 100644 --- a/environments/factory/additional/dest/dest_enitites.py +++ b/environments/factory/additional/dest/dest_enitites.py @@ -38,7 +38,8 @@ class Destination(Entity): def agent_is_dwelling(self, agent: Agent): return self._per_agent_times[agent.name] < self.dwell_time - def summarize_state(self, n_steps=None) -> dict: - state_summary = super().summarize_state(n_steps=n_steps) - state_summary.update(per_agent_times=self._per_agent_times) + def summarize_state(self) -> dict: + state_summary = super().summarize_state() + state_summary.update(per_agent_times=[ + dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time) return state_summary diff --git a/environments/factory/additional/dirt/dirt_collections.py b/environments/factory/additional/dirt/dirt_collections.py index 79ae2d4..1b56159 100644 --- a/environments/factory/additional/dirt/dirt_collections.py +++ b/environments/factory/additional/dirt/dirt_collections.py @@ -1,13 +1,13 @@ -from environments.factory.additional.dirt.dirt_entity import Dirt +from environments.factory.additional.dirt.dirt_entity import DirtPile from environments.factory.additional.dirt.dirt_util import DirtProperties from environments.factory.base.objects import Floor from environments.factory.base.registers import EntityCollection from environments.factory.additional.dirt.dirt_util import Constants as c -class DirtRegister(EntityCollection): +class DirtPiles(EntityCollection): - _accepted_objects = Dirt + _accepted_objects = DirtPile @property def amount(self): @@ -18,7 +18,7 @@ class DirtRegister(EntityCollection): return self._dirt_properties def __init__(self, dirt_properties, *args): - super(DirtRegister, self).__init__(*args) + super(DirtPiles, self).__init__(*args) self._dirt_properties: DirtProperties = dirt_properties def spawn_dirt(self, then_dirty_tiles) -> bool: @@ -28,7 +28,7 @@ class DirtRegister(EntityCollection): if not self.amount > self.dirt_properties.max_global_amount: dirt = self.by_pos(tile.pos) if dirt is None: - dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount) + dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount) self.add_item(dirt) else: new_value = dirt.amount + self.dirt_properties.max_spawn_amount @@ -38,5 +38,5 @@ class DirtRegister(EntityCollection): return c.VALID def __repr__(self): - s = super(DirtRegister, self).__repr__() + s = super(DirtPiles, self).__repr__() return f'{s[:-1]}, {self.amount})' diff --git a/environments/factory/additional/dirt/dirt_entity.py b/environments/factory/additional/dirt/dirt_entity.py index 69c94ca..f491afb 100644 --- a/environments/factory/additional/dirt/dirt_entity.py +++ b/environments/factory/additional/dirt/dirt_entity.py @@ -1,7 +1,7 @@ from environments.factory.base.objects import Entity -class Dirt(Entity): +class DirtPile(Entity): @property def amount(self): @@ -13,14 +13,14 @@ class Dirt(Entity): return self._amount def __init__(self, *args, amount=None, **kwargs): - super(Dirt, self).__init__(*args, **kwargs) + super(DirtPile, self).__init__(*args, **kwargs) self._amount = amount def set_new_amount(self, amount): self._amount = amount self._collection.notify_change_to_value(self) - def summarize_state(self, **kwargs): - state_dict = super().summarize_state(**kwargs) + def summarize_state(self): + state_dict = super().summarize_state() state_dict.update(amount=float(self.amount)) return state_dict diff --git a/environments/factory/additional/dirt/dirt_util.py b/environments/factory/additional/dirt/dirt_util.py index b5a23bf..2c5dca9 100644 --- a/environments/factory/additional/dirt/dirt_util.py +++ b/environments/factory/additional/dirt/dirt_util.py @@ -4,7 +4,7 @@ from environments.helpers import Constants as BaseConstants, EnvActions as BaseA class Constants(BaseConstants): - DIRT = 'Dirt' + DIRT = 'DirtPile' class Actions(BaseActions): diff --git a/environments/factory/additional/dirt/factory_dirt.py b/environments/factory/additional/dirt/factory_dirt.py index dcdcd9d..6a06a79 100644 --- a/environments/factory/additional/dirt/factory_dirt.py +++ b/environments/factory/additional/dirt/factory_dirt.py @@ -5,8 +5,8 @@ import random import numpy as np -from environments.factory.additional.dirt.dirt_collections import DirtRegister -from environments.factory.additional.dirt.dirt_entity import Dirt +from environments.factory.additional.dirt.dirt_collections import DirtPiles +from environments.factory.additional.dirt.dirt_entity import DirtPile from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties from environments.factory.base.base_factory import BaseFactory @@ -43,7 +43,7 @@ class DirtFactory(BaseFactory): @property def entities_hook(self) -> Dict[(str, Entities)]: super_entities = super().entities_hook - dirt_register = DirtRegister(self.dirt_prop, self._level_shape) + dirt_register = DirtPiles(self.dirt_prop, self._level_shape) super_entities.update(({c.DIRT: dirt_register})) return super_entities @@ -57,7 +57,7 @@ class DirtFactory(BaseFactory): self.dirt_prop = dirt_prop self.rewards_dirt = rewards_dirt self._dirt_rng = np.random.default_rng(env_seed) - self._dirt: DirtRegister + self._dirt: DirtPiles kwargs.update(env_seed=env_seed) # TODO: Reset ---> document this super().__init__(*args, **kwargs) @@ -96,7 +96,7 @@ class DirtFactory(BaseFactory): def trigger_dirt_spawn(self, initial_spawn=False): dirt_rng = self._dirt_rng free_for_dirt = [x for x in self[c.FLOOR] - if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt)) + if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile)) ] self._dirt_rng.shuffle(free_for_dirt) if initial_spawn: @@ -123,7 +123,7 @@ class DirtFactory(BaseFactory): new_pos_dirt = self[c.DIRT].by_pos(agent.pos) new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt)) if self._next_dirt_spawn < 0: - pass # No Dirt Spawn + pass # No DirtPile Spawn elif not self._next_dirt_spawn: self.trigger_dirt_spawn() self._next_dirt_spawn = self.dirt_prop.spawn_frequency diff --git a/environments/factory/additional/item/factory_item.py b/environments/factory/additional/item/factory_item.py index dc6b3e8..ee84a66 100644 --- a/environments/factory/additional/item/factory_item.py +++ b/environments/factory/additional/item/factory_item.py @@ -3,7 +3,7 @@ from typing import List, Union, Dict import numpy as np import random -from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations +from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties from environments.factory.base.base_factory import BaseFactory from environments.factory.base.objects import Agent, Action @@ -49,7 +49,7 @@ class ItemFactory(BaseFactory): entity_kwargs=dict( storage_size_until_full=self.item_prop.max_dropoff_storage_size) ) - item_register = ItemRegister(self._level_shape) + item_register = Items(self._level_shape) empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items] item_register.spawn_items(empty_tiles) diff --git a/environments/factory/additional/item/item_collections.py b/environments/factory/additional/item/item_collections.py index 7a12b88..b99fc9b 100644 --- a/environments/factory/additional/item/item_collections.py +++ b/environments/factory/additional/item/item_collections.py @@ -7,7 +7,7 @@ from environments.factory.base.registers import EntityCollection, BoundEnvObjCol from environments.factory.additional.item.item_entities import Item, DropOffLocation -class ItemRegister(EntityCollection): +class Items(EntityCollection): _accepted_objects = Item @@ -37,9 +37,9 @@ class Inventory(BoundEnvObjCollection): return super(Inventory, self).as_array() def summarize_states(self, **kwargs): - attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} - attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()})) - attr_dict.update(dict(name=self.name)) + attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} + attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()])) + attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name)) return attr_dict def pop(self): @@ -79,9 +79,11 @@ class Inventories(ObjectCollection): return None def summarize_states(self, **kwargs): - return {key: val.summarize_states(**kwargs) for key, val in self.items()} + return [val.summarize_states(**kwargs) for key, val in self.items()] class DropOffLocations(EntityCollection): _accepted_objects = DropOffLocation + _stateless_entities = True + diff --git a/environments/factory/additional/item/item_entities.py b/environments/factory/additional/item/item_entities.py index c8566c8..59a9afc 100644 --- a/environments/factory/additional/item/item_entities.py +++ b/environments/factory/additional/item/item_entities.py @@ -26,7 +26,7 @@ class Item(Entity): def set_tile_to(self, no_pos_tile): self._tile = no_pos_tile - def summarize_state(self, **_) -> dict: + def summarize_state(self) -> dict: super_summarization = super(Item, self).summarize_state() super_summarization.update(dict(auto_despawn=self.auto_despawn)) return super_summarization @@ -55,7 +55,3 @@ class DropOffLocation(Entity): @property def is_full(self): return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage) - - def summarize_state(self, n_steps=None) -> dict: - if n_steps == h.STEPS_START: - return super().summarize_state(n_steps=n_steps) diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 7233270..5ef9227 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -71,6 +71,12 @@ class BaseFactory(gym.Env): d['class_name'] = self.__class__.__name__ return d + @property + def summarize_header(self): + summary_dict = self._summarize_state(stateless_entities=True) + summary_dict.update(actions=self._actions.summarize()) + return summary_dict + def __enter__(self): return self if self.obs_prop.frames_to_stack == 0 else \ MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack)) @@ -665,12 +671,12 @@ class BaseFactory(gym.Env): else: return [] - def _summarize_state(self): + def _summarize_state(self, stateless_entities=False): summary = {f'{REC_TAC}step': self._steps} for entity_group in self._entities: - summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)}) - + if entity_group.is_stateless == stateless_entities: + summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()}) return summary def print(self, string): diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index 0e6fbab..7afa9d7 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -86,7 +86,7 @@ class EnvObject(Object): # TODO: Missing Documentation class Entity(EnvObject): - """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc...""" + """Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc...""" @property def can_collide(self): @@ -113,7 +113,7 @@ class Entity(EnvObject): self._tile = tile tile.enter(self) - def summarize_state(self, **_) -> dict: + def summarize_state(self) -> dict: return dict(name=str(self.name), x=int(self.x), y=int(self.y), tile=str(self.tile.name), can_collide=bool(self.can_collide)) @@ -338,8 +338,8 @@ class Door(Entity): if not closed_on_init: self._open() - def summarize_state(self, **kwargs): - state_dict = super().summarize_state(**kwargs) + def summarize_state(self): + state_dict = super().summarize_state() state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) return state_dict @@ -402,7 +402,7 @@ class Agent(MoveableEntity): # if attr.startswith('temp'): self.step_result = None - def summarize_state(self, **kwargs): - state_dict = super().summarize_state(**kwargs) + def summarize_state(self): + state_dict = super().summarize_state() state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name'])) return state_dict diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 5bc800a..2664786 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -19,6 +19,11 @@ from environments.helpers import Constants as c class ObjectCollection: _accepted_objects = Object + _stateless_entities = False + + @property + def is_stateless(self): + return self._stateless_entities @property def name(self): @@ -116,8 +121,8 @@ class EnvObjectCollection(ObjectCollection): self._lazy_eval_transforms = [] return self._array - def summarize_states(self, n_steps=None): - return [entity.summarize_state(n_steps=n_steps) for entity in self.values()] + def summarize_states(self): + return [entity.summarize_state() for entity in self.values()] def notify_change_to_free(self, env_object: EnvObject): self._array_change_notifyer(env_object, value=c.FREE_CELL) @@ -290,9 +295,6 @@ class GlobalPositions(EnvObjectCollection): # noinspection PyTypeChecker self.add_additional_items(global_positions) - def summarize_states(self, n_steps=None): - return {} - def idx_by_entity(self, entity): try: return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity))) @@ -376,6 +378,7 @@ class Entities(ObjectCollection): class Walls(EntityCollection): _accepted_objects = Wall + _stateless_entities = True def as_array(self): if not np.any(self._array): @@ -406,15 +409,10 @@ class Walls(EntityCollection): def from_tiles(cls, tiles, *args, **kwargs): raise RuntimeError() - def summarize_states(self, n_steps=None): - if n_steps == h.STEPS_START: - return super(Walls, self).summarize_states(n_steps=n_steps) - else: - return {} - class Floors(Walls): _accepted_objects = Floor + _stateless_entities = True def __init__(self, *args, is_blocking_light=False, **kwargs): super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs) @@ -436,10 +434,6 @@ class Floors(Walls): def from_tiles(cls, tiles, *args, **kwargs): raise RuntimeError() - def summarize_states(self, n_steps=None): - # Do not summarize - return {} - class Agents(MovingEntityObjectCollection): _accepted_objects = Agent @@ -521,6 +515,9 @@ class Actions(ObjectCollection): def is_moving_action(self, action: Union[int]): return action in self.movement_actions.values() + def summarize(self): + return [dict(name=action.identifier) for action in self] + class Zones(ObjectCollection): diff --git a/environments/factory/factory_dirt_stationary_machines.py b/environments/factory/factory_dirt_stationary_machines.py index 7c79444..3afe15c 100644 --- a/environments/factory/factory_dirt_stationary_machines.py +++ b/environments/factory/factory_dirt_stationary_machines.py @@ -4,8 +4,8 @@ import numpy as np from environments.factory.base.objects import Agent, Entity, Action from environments.factory.factory_dirt import DirtFactory -from environments.factory.additional.dirt.dirt_collections import DirtRegister -from environments.factory.additional.dirt.dirt_entity import Dirt +from environments.factory.additional.dirt.dirt_collections import DirtPiles +from environments.factory.additional.dirt.dirt_entity import DirtPile from environments.factory.base.objects import Floor from environments.factory.base.registers import Floors, Entities, EntityCollection diff --git a/environments/logging/recorder.py b/environments/logging/recorder.py index 9abf683..893a0a3 100644 --- a/environments/logging/recorder.py +++ b/environments/logging/recorder.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict from os import PathLike from pathlib import Path @@ -21,6 +22,7 @@ class EnvRecorder(BaseCallback): self._recorder_dict = defaultdict(list) self._recorder_out_list = list() self._episode_counter = 1 + self._do_record_dict = defaultdict(lambda: False) if isinstance(entities, str): if entities.lower() == 'all': self._entities = None @@ -58,7 +60,11 @@ class EnvRecorder(BaseCallback): def step(self, actions): step_result = self.unwrapped.step(actions) - self._on_step() + if self.do_record_episode(0): + info = step_result[-1] + self._read_info(0, info) + if self._do_record_dict[0]: + self._read_done(0, step_result[-2]) return step_result def finalize(self): @@ -71,7 +77,8 @@ class EnvRecorder(BaseCallback): # cls.out_file.unlink(missing_ok=True) with filepath.open('w') as f: out_dict = {'n_episodes': self._episode_counter, - 'header': self.unwrapped.params, + 'env_params': self.unwrapped.params, + 'header': self.unwrapped.summarize_header, 'episodes': self._recorder_out_list } try: @@ -99,18 +106,29 @@ class EnvRecorder(BaseCallback): if save_trajectory_map: raise NotImplementedError('This has not yet been implemented.') + def do_record_episode(self, env_idx): + if not self._recorder_dict[env_idx]: + if self.freq: + self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0 + else: + self._do_record_dict[env_idx] = False + warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n' + 'Nothing will be recorded') + self._episode_counter += 1 + else: + pass + return self._do_record_dict[env_idx] + def _on_step(self) -> bool: - do_record = self.freq == -1 or self._episode_counter % self.freq == 0 for env_idx, info in enumerate(self.locals.get('infos', [])): - if do_record: + if self._do_record_dict[env_idx]: self._read_info(env_idx, info) dones = list(enumerate(self.locals.get('dones', []))) dones.extend(list(enumerate(self.locals.get('done', [])))) for env_idx, done in dones: - if do_record: + if self._do_record_dict[env_idx]: self._read_done(env_idx, done) - if done: - self._episode_counter += 1 + return True def _on_training_end(self) -> None: diff --git a/quickstart/combine_and_monitor_rerun.py b/quickstart/combine_and_monitor_rerun.py index 6c6aea8..031f26b 100644 --- a/quickstart/combine_and_monitor_rerun.py +++ b/quickstart/combine_and_monitor_rerun.py @@ -60,6 +60,7 @@ def encapsule_env_factory(env_fctry, env_kwrgs): if __name__ == '__main__': + render = False # Define Global Env Parameters # Define properties object parameters factory_kwargs = dict( @@ -140,17 +141,17 @@ if __name__ == '__main__': combined_env_kwargs[key] = val else: assert combined_env_kwargs[key] == val + del combined_env_kwargs['key'] combined_env_kwargs.update(n_agents=len(comparable_runs)) - - with(type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs)) as combEnv: + with type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs) as combEnv: # EnvMonitor Init comb = f'comb_{model_name}_{seed}' comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick' - comb_recorder_path = combinations_path / comb / f'{comb}_recorder.pick' + comb_recorder_path = combinations_path / comb / f'{comb}_recorder.json' comb_monitor_path.parent.mkdir(parents=True, exist_ok=True) monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path) - # monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_monitor_path) + monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_recorder_path, freq=1) # Evaluation starts here ##################################################### # Load all models @@ -164,8 +165,9 @@ if __name__ == '__main__': *(agent.named_action_space for agent in loaded_models) ) - for episode in range(50): - obs, _ = monitoredCombEnv.reset(), monitoredCombEnv.render() + for episode in range(1): + obs = monitoredCombEnv.reset() + if render: monitoredCombEnv.render() rew, done_bool = 0, False while not done_bool: actions = [] @@ -176,12 +178,12 @@ if __name__ == '__main__': obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions) rew += step_r - monitoredCombEnv.render() + if render: monitoredCombEnv.render() if done_bool: break print(f'Factory run {episode} done, reward is:\n {rew}') # Eval monitor outputs are automatically stored by the monitor object # TODO: Plotting - monitoredCombEnv.save_records(comb_monitor_path) + monitoredCombEnv.save_records() monitoredCombEnv.save_run() pass