mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	recorder fixed
This commit is contained in:
		| @@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions | ||||
|  | ||||
|  | ||||
| class Constants(BaseConstants): | ||||
|     DIRT = 'Dirt' | ||||
|     DIRT = 'DirtPile' | ||||
|  | ||||
|  | ||||
| class Actions(BaseActions): | ||||
|   | ||||
| @@ -2,24 +2,19 @@ from environments.factory.additional.btry.btry_objects import Battery, ChargePod | ||||
| from environments.factory.base.registers import EnvObjectCollection, EntityCollection | ||||
|  | ||||
|  | ||||
| class BatteriesRegister(EnvObjectCollection): | ||||
| class Batteries(EnvObjectCollection): | ||||
|  | ||||
|     _accepted_objects = Battery | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(BatteriesRegister, self).__init__(*args, individual_slices=True, | ||||
|                                                 is_blocking_light=False, can_be_shadowed=False, **kwargs) | ||||
|         super(Batteries, self).__init__(*args, individual_slices=True, | ||||
|                                         is_blocking_light=False, can_be_shadowed=False, **kwargs) | ||||
|         self.is_observable = True | ||||
|  | ||||
|     def spawn_batteries(self, agents, initial_charge_level): | ||||
|         batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] | ||||
|         self.add_additional_items(batteries) | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         # as dict with additional nesting | ||||
|         # return dict(items=super(Inventories, cls).summarize_states()) | ||||
|         return super(BatteriesRegister, self).summarize_states(n_steps=n_steps) | ||||
|  | ||||
|     # Todo Move this to Mixin! | ||||
|     def by_entity(self, entity): | ||||
|         try: | ||||
| @@ -40,11 +35,7 @@ class BatteriesRegister(EnvObjectCollection): | ||||
| class ChargePods(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = ChargePod | ||||
|     _stateless_entities = True | ||||
|  | ||||
|     def __repr__(self): | ||||
|         super(ChargePods, self).__repr__() | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         # as dict with additional nesting | ||||
|         # return dict(items=super(Inventories, cls).summarize_states()) | ||||
|         return super(ChargePods, self).summarize_states(n_steps=n_steps) | ||||
|   | ||||
| @@ -35,7 +35,7 @@ class Battery(BoundingMixin, EnvObject): | ||||
|  | ||||
|     def summarize_state(self, **_): | ||||
|         attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} | ||||
|         attr_dict.update(dict(name=self.name)) | ||||
|         attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name)) | ||||
|         return attr_dict | ||||
|  | ||||
|  | ||||
| @@ -58,10 +58,3 @@ class ChargePod(Entity): | ||||
|             return c.NOT_VALID | ||||
|         valid = battery.do_charge_action(self.charge_rate) | ||||
|         return valid | ||||
|  | ||||
|     def summarize_state(self, n_steps=None) -> dict: | ||||
|         if n_steps == h.STEPS_START: | ||||
|             summary = super().summarize_state(n_steps=n_steps) | ||||
|             return summary | ||||
|         else: | ||||
|             {} | ||||
|   | ||||
| @@ -2,7 +2,7 @@ from typing import Dict, List | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods | ||||
| from environments.factory.additional.btry.btry_collections import Batteries, ChargePods | ||||
| from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.factory.base.objects import Agent, Action | ||||
| @@ -45,8 +45,8 @@ class BatteryFactory(BaseFactory): | ||||
|                                multi_charge=self.btry_prop.multi_charge) | ||||
|         ) | ||||
|  | ||||
|         batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), | ||||
|                                       ) | ||||
|         batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), | ||||
|                               ) | ||||
|         batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge) | ||||
|         super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods}) | ||||
|         return super_entities | ||||
|   | ||||
| @@ -25,9 +25,6 @@ class Destinations(EntityCollection): | ||||
|     def __repr__(self): | ||||
|         return super(Destinations, self).__repr__() | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         return {} | ||||
|  | ||||
|  | ||||
| class ReachedDestinations(Destinations): | ||||
|     _accepted_objects = Destination | ||||
| @@ -37,8 +34,5 @@ class ReachedDestinations(Destinations): | ||||
|         self.can_be_shadowed = False | ||||
|         self.is_blocking_light = False | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         return {} | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return super(ReachedDestinations, self).__repr__() | ||||
|   | ||||
| @@ -38,7 +38,8 @@ class Destination(Entity): | ||||
|     def agent_is_dwelling(self, agent: Agent): | ||||
|         return self._per_agent_times[agent.name] < self.dwell_time | ||||
|  | ||||
|     def summarize_state(self, n_steps=None) -> dict: | ||||
|         state_summary = super().summarize_state(n_steps=n_steps) | ||||
|         state_summary.update(per_agent_times=self._per_agent_times) | ||||
|     def summarize_state(self) -> dict: | ||||
|         state_summary = super().summarize_state() | ||||
|         state_summary.update(per_agent_times=[ | ||||
|             dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time) | ||||
|         return state_summary | ||||
|   | ||||
| @@ -1,13 +1,13 @@ | ||||
| from environments.factory.additional.dirt.dirt_entity import Dirt | ||||
| from environments.factory.additional.dirt.dirt_entity import DirtPile | ||||
| from environments.factory.additional.dirt.dirt_util import DirtProperties | ||||
| from environments.factory.base.objects import Floor | ||||
| from environments.factory.base.registers import EntityCollection | ||||
| from environments.factory.additional.dirt.dirt_util import Constants as c | ||||
|  | ||||
|  | ||||
| class DirtRegister(EntityCollection): | ||||
| class DirtPiles(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = Dirt | ||||
|     _accepted_objects = DirtPile | ||||
|  | ||||
|     @property | ||||
|     def amount(self): | ||||
| @@ -18,7 +18,7 @@ class DirtRegister(EntityCollection): | ||||
|         return self._dirt_properties | ||||
|  | ||||
|     def __init__(self, dirt_properties, *args): | ||||
|         super(DirtRegister, self).__init__(*args) | ||||
|         super(DirtPiles, self).__init__(*args) | ||||
|         self._dirt_properties: DirtProperties = dirt_properties | ||||
|  | ||||
|     def spawn_dirt(self, then_dirty_tiles) -> bool: | ||||
| @@ -28,7 +28,7 @@ class DirtRegister(EntityCollection): | ||||
|             if not self.amount > self.dirt_properties.max_global_amount: | ||||
|                 dirt = self.by_pos(tile.pos) | ||||
|                 if dirt is None: | ||||
|                     dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount) | ||||
|                     dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount) | ||||
|                     self.add_item(dirt) | ||||
|                 else: | ||||
|                     new_value = dirt.amount + self.dirt_properties.max_spawn_amount | ||||
| @@ -38,5 +38,5 @@ class DirtRegister(EntityCollection): | ||||
|         return c.VALID | ||||
|  | ||||
|     def __repr__(self): | ||||
|         s = super(DirtRegister, self).__repr__() | ||||
|         s = super(DirtPiles, self).__repr__() | ||||
|         return f'{s[:-1]}, {self.amount})' | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from environments.factory.base.objects import Entity | ||||
|  | ||||
|  | ||||
| class Dirt(Entity): | ||||
| class DirtPile(Entity): | ||||
|  | ||||
|     @property | ||||
|     def amount(self): | ||||
| @@ -13,14 +13,14 @@ class Dirt(Entity): | ||||
|         return self._amount | ||||
|  | ||||
|     def __init__(self, *args, amount=None, **kwargs): | ||||
|         super(Dirt, self).__init__(*args, **kwargs) | ||||
|         super(DirtPile, self).__init__(*args, **kwargs) | ||||
|         self._amount = amount | ||||
|  | ||||
|     def set_new_amount(self, amount): | ||||
|         self._amount = amount | ||||
|         self._collection.notify_change_to_value(self) | ||||
|  | ||||
|     def summarize_state(self, **kwargs): | ||||
|         state_dict = super().summarize_state(**kwargs) | ||||
|     def summarize_state(self): | ||||
|         state_dict = super().summarize_state() | ||||
|         state_dict.update(amount=float(self.amount)) | ||||
|         return state_dict | ||||
|   | ||||
| @@ -4,7 +4,7 @@ from environments.helpers import Constants as BaseConstants, EnvActions as BaseA | ||||
|  | ||||
|  | ||||
| class Constants(BaseConstants): | ||||
|     DIRT = 'Dirt' | ||||
|     DIRT = 'DirtPile' | ||||
|  | ||||
|  | ||||
| class Actions(BaseActions): | ||||
|   | ||||
| @@ -5,8 +5,8 @@ import random | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from environments.factory.additional.dirt.dirt_collections import DirtRegister | ||||
| from environments.factory.additional.dirt.dirt_entity import Dirt | ||||
| from environments.factory.additional.dirt.dirt_collections import DirtPiles | ||||
| from environments.factory.additional.dirt.dirt_entity import DirtPile | ||||
| from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties | ||||
|  | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| @@ -43,7 +43,7 @@ class DirtFactory(BaseFactory): | ||||
|     @property | ||||
|     def entities_hook(self) -> Dict[(str, Entities)]: | ||||
|         super_entities = super().entities_hook | ||||
|         dirt_register = DirtRegister(self.dirt_prop, self._level_shape) | ||||
|         dirt_register = DirtPiles(self.dirt_prop, self._level_shape) | ||||
|         super_entities.update(({c.DIRT: dirt_register})) | ||||
|         return super_entities | ||||
|  | ||||
| @@ -57,7 +57,7 @@ class DirtFactory(BaseFactory): | ||||
|         self.dirt_prop = dirt_prop | ||||
|         self.rewards_dirt = rewards_dirt | ||||
|         self._dirt_rng = np.random.default_rng(env_seed) | ||||
|         self._dirt: DirtRegister | ||||
|         self._dirt: DirtPiles | ||||
|         kwargs.update(env_seed=env_seed) | ||||
|         # TODO: Reset ---> document this | ||||
|         super().__init__(*args, **kwargs) | ||||
| @@ -96,7 +96,7 @@ class DirtFactory(BaseFactory): | ||||
|     def trigger_dirt_spawn(self, initial_spawn=False): | ||||
|         dirt_rng = self._dirt_rng | ||||
|         free_for_dirt = [x for x in self[c.FLOOR] | ||||
|                          if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt)) | ||||
|                          if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile)) | ||||
|                          ] | ||||
|         self._dirt_rng.shuffle(free_for_dirt) | ||||
|         if initial_spawn: | ||||
| @@ -123,7 +123,7 @@ class DirtFactory(BaseFactory): | ||||
|                                         new_pos_dirt = self[c.DIRT].by_pos(agent.pos) | ||||
|                                         new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt)) | ||||
|         if self._next_dirt_spawn < 0: | ||||
|             pass  # No Dirt Spawn | ||||
|             pass  # No DirtPile Spawn | ||||
|         elif not self._next_dirt_spawn: | ||||
|             self.trigger_dirt_spawn() | ||||
|             self._next_dirt_spawn = self.dirt_prop.spawn_frequency | ||||
|   | ||||
| @@ -3,7 +3,7 @@ from typing import List, Union, Dict | ||||
| import numpy as np | ||||
| import random | ||||
|  | ||||
| from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations | ||||
| from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations | ||||
| from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.factory.base.objects import Agent, Action | ||||
| @@ -49,7 +49,7 @@ class ItemFactory(BaseFactory): | ||||
|             entity_kwargs=dict( | ||||
|                 storage_size_until_full=self.item_prop.max_dropoff_storage_size) | ||||
|         ) | ||||
|         item_register = ItemRegister(self._level_shape) | ||||
|         item_register = Items(self._level_shape) | ||||
|         empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items] | ||||
|         item_register.spawn_items(empty_tiles) | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,7 @@ from environments.factory.base.registers import EntityCollection, BoundEnvObjCol | ||||
| from environments.factory.additional.item.item_entities import Item, DropOffLocation | ||||
|  | ||||
|  | ||||
| class ItemRegister(EntityCollection): | ||||
| class Items(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = Item | ||||
|  | ||||
| @@ -37,9 +37,9 @@ class Inventory(BoundEnvObjCollection): | ||||
|         return super(Inventory, self).as_array() | ||||
|  | ||||
|     def summarize_states(self, **kwargs): | ||||
|         attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} | ||||
|         attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()})) | ||||
|         attr_dict.update(dict(name=self.name)) | ||||
|         attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} | ||||
|         attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()])) | ||||
|         attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name)) | ||||
|         return attr_dict | ||||
|  | ||||
|     def pop(self): | ||||
| @@ -79,9 +79,11 @@ class Inventories(ObjectCollection): | ||||
|             return None | ||||
|  | ||||
|     def summarize_states(self, **kwargs): | ||||
|         return {key: val.summarize_states(**kwargs) for key, val in self.items()} | ||||
|         return [val.summarize_states(**kwargs) for key, val in self.items()] | ||||
|  | ||||
|  | ||||
| class DropOffLocations(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = DropOffLocation | ||||
|     _stateless_entities = True | ||||
|  | ||||
|   | ||||
| @@ -26,7 +26,7 @@ class Item(Entity): | ||||
|     def set_tile_to(self, no_pos_tile): | ||||
|         self._tile = no_pos_tile | ||||
|  | ||||
|     def summarize_state(self, **_) -> dict: | ||||
|     def summarize_state(self) -> dict: | ||||
|         super_summarization = super(Item, self).summarize_state() | ||||
|         super_summarization.update(dict(auto_despawn=self.auto_despawn)) | ||||
|         return super_summarization | ||||
| @@ -55,7 +55,3 @@ class DropOffLocation(Entity): | ||||
|     @property | ||||
|     def is_full(self): | ||||
|         return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage) | ||||
|  | ||||
|     def summarize_state(self, n_steps=None) -> dict: | ||||
|         if n_steps == h.STEPS_START: | ||||
|             return super().summarize_state(n_steps=n_steps) | ||||
|   | ||||
| @@ -71,6 +71,12 @@ class BaseFactory(gym.Env): | ||||
|         d['class_name'] = self.__class__.__name__ | ||||
|         return d | ||||
|  | ||||
|     @property | ||||
|     def summarize_header(self): | ||||
|         summary_dict = self._summarize_state(stateless_entities=True) | ||||
|         summary_dict.update(actions=self._actions.summarize()) | ||||
|         return summary_dict | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self if self.obs_prop.frames_to_stack == 0 else \ | ||||
|             MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack)) | ||||
| @@ -665,12 +671,12 @@ class BaseFactory(gym.Env): | ||||
|         else: | ||||
|             return [] | ||||
|  | ||||
|     def _summarize_state(self): | ||||
|     def _summarize_state(self, stateless_entities=False): | ||||
|         summary = {f'{REC_TAC}step': self._steps} | ||||
|  | ||||
|         for entity_group in self._entities: | ||||
|             summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)}) | ||||
|  | ||||
|             if entity_group.is_stateless == stateless_entities: | ||||
|                 summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()}) | ||||
|         return summary | ||||
|  | ||||
|     def print(self, string): | ||||
|   | ||||
| @@ -86,7 +86,7 @@ class EnvObject(Object): | ||||
|  | ||||
| # TODO: Missing Documentation | ||||
| class Entity(EnvObject): | ||||
|     """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc...""" | ||||
|     """Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc...""" | ||||
|  | ||||
|     @property | ||||
|     def can_collide(self): | ||||
| @@ -113,7 +113,7 @@ class Entity(EnvObject): | ||||
|         self._tile = tile | ||||
|         tile.enter(self) | ||||
|  | ||||
|     def summarize_state(self, **_) -> dict: | ||||
|     def summarize_state(self) -> dict: | ||||
|         return dict(name=str(self.name), x=int(self.x), y=int(self.y), | ||||
|                     tile=str(self.tile.name), can_collide=bool(self.can_collide)) | ||||
|  | ||||
| @@ -338,8 +338,8 @@ class Door(Entity): | ||||
|         if not closed_on_init: | ||||
|             self._open() | ||||
|  | ||||
|     def summarize_state(self, **kwargs): | ||||
|         state_dict = super().summarize_state(**kwargs) | ||||
|     def summarize_state(self): | ||||
|         state_dict = super().summarize_state() | ||||
|         state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close)) | ||||
|         return state_dict | ||||
|  | ||||
| @@ -402,7 +402,7 @@ class Agent(MoveableEntity): | ||||
|         #   if attr.startswith('temp'): | ||||
|         self.step_result = None | ||||
|  | ||||
|     def summarize_state(self, **kwargs): | ||||
|         state_dict = super().summarize_state(**kwargs) | ||||
|     def summarize_state(self): | ||||
|         state_dict = super().summarize_state() | ||||
|         state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name'])) | ||||
|         return state_dict | ||||
|   | ||||
| @@ -19,6 +19,11 @@ from environments.helpers import Constants as c | ||||
|  | ||||
| class ObjectCollection: | ||||
|     _accepted_objects = Object | ||||
|     _stateless_entities = False | ||||
|  | ||||
|     @property | ||||
|     def is_stateless(self): | ||||
|         return self._stateless_entities | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
| @@ -116,8 +121,8 @@ class EnvObjectCollection(ObjectCollection): | ||||
|             self._lazy_eval_transforms = [] | ||||
|         return self._array | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         return [entity.summarize_state(n_steps=n_steps) for entity in self.values()] | ||||
|     def summarize_states(self): | ||||
|         return [entity.summarize_state() for entity in self.values()] | ||||
|  | ||||
|     def notify_change_to_free(self, env_object: EnvObject): | ||||
|         self._array_change_notifyer(env_object, value=c.FREE_CELL) | ||||
| @@ -290,9 +295,6 @@ class GlobalPositions(EnvObjectCollection): | ||||
|         # noinspection PyTypeChecker | ||||
|         self.add_additional_items(global_positions) | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         return {} | ||||
|  | ||||
|     def idx_by_entity(self, entity): | ||||
|         try: | ||||
|             return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity))) | ||||
| @@ -376,6 +378,7 @@ class Entities(ObjectCollection): | ||||
|  | ||||
| class Walls(EntityCollection): | ||||
|     _accepted_objects = Wall | ||||
|     _stateless_entities = True | ||||
|  | ||||
|     def as_array(self): | ||||
|         if not np.any(self._array): | ||||
| @@ -406,15 +409,10 @@ class Walls(EntityCollection): | ||||
|     def from_tiles(cls, tiles, *args, **kwargs): | ||||
|         raise RuntimeError() | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         if n_steps == h.STEPS_START: | ||||
|             return super(Walls, self).summarize_states(n_steps=n_steps) | ||||
|         else: | ||||
|             return {} | ||||
|  | ||||
|  | ||||
| class Floors(Walls): | ||||
|     _accepted_objects = Floor | ||||
|     _stateless_entities = True | ||||
|  | ||||
|     def __init__(self, *args, is_blocking_light=False, **kwargs): | ||||
|         super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs) | ||||
| @@ -436,10 +434,6 @@ class Floors(Walls): | ||||
|     def from_tiles(cls, tiles, *args, **kwargs): | ||||
|         raise RuntimeError() | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         # Do not summarize | ||||
|         return {} | ||||
|  | ||||
|  | ||||
| class Agents(MovingEntityObjectCollection): | ||||
|     _accepted_objects = Agent | ||||
| @@ -521,6 +515,9 @@ class Actions(ObjectCollection): | ||||
|     def is_moving_action(self, action: Union[int]): | ||||
|         return action in self.movement_actions.values() | ||||
|  | ||||
|     def summarize(self): | ||||
|         return [dict(name=action.identifier) for action in self] | ||||
|  | ||||
|  | ||||
| class Zones(ObjectCollection): | ||||
|  | ||||
|   | ||||
| @@ -4,8 +4,8 @@ import numpy as np | ||||
|  | ||||
| from environments.factory.base.objects import Agent, Entity, Action | ||||
| from environments.factory.factory_dirt import DirtFactory | ||||
| from environments.factory.additional.dirt.dirt_collections import DirtRegister | ||||
| from environments.factory.additional.dirt.dirt_entity import Dirt | ||||
| from environments.factory.additional.dirt.dirt_collections import DirtPiles | ||||
| from environments.factory.additional.dirt.dirt_entity import DirtPile | ||||
| from environments.factory.base.objects import Floor | ||||
| from environments.factory.base.registers import Floors, Entities, EntityCollection | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| import warnings | ||||
| from collections import defaultdict | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| @@ -21,6 +22,7 @@ class EnvRecorder(BaseCallback): | ||||
|         self._recorder_dict = defaultdict(list) | ||||
|         self._recorder_out_list = list() | ||||
|         self._episode_counter = 1 | ||||
|         self._do_record_dict = defaultdict(lambda: False) | ||||
|         if isinstance(entities, str): | ||||
|             if entities.lower() == 'all': | ||||
|                 self._entities = None | ||||
| @@ -58,7 +60,11 @@ class EnvRecorder(BaseCallback): | ||||
|  | ||||
|     def step(self, actions): | ||||
|         step_result = self.unwrapped.step(actions) | ||||
|         self._on_step() | ||||
|         if self.do_record_episode(0): | ||||
|             info = step_result[-1] | ||||
|             self._read_info(0, info) | ||||
|         if self._do_record_dict[0]: | ||||
|             self._read_done(0, step_result[-2]) | ||||
|         return step_result | ||||
|  | ||||
|     def finalize(self): | ||||
| @@ -71,7 +77,8 @@ class EnvRecorder(BaseCallback): | ||||
|         # cls.out_file.unlink(missing_ok=True) | ||||
|         with filepath.open('w') as f: | ||||
|             out_dict = {'n_episodes': self._episode_counter, | ||||
|                         'header': self.unwrapped.params, | ||||
|                         'env_params': self.unwrapped.params, | ||||
|                         'header': self.unwrapped.summarize_header, | ||||
|                         'episodes': self._recorder_out_list | ||||
|                         } | ||||
|             try: | ||||
| @@ -99,18 +106,29 @@ class EnvRecorder(BaseCallback): | ||||
|         if save_trajectory_map: | ||||
|             raise NotImplementedError('This has not yet been implemented.') | ||||
|  | ||||
|     def do_record_episode(self, env_idx): | ||||
|         if not self._recorder_dict[env_idx]: | ||||
|             if self.freq: | ||||
|                 self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0 | ||||
|             else: | ||||
|                 self._do_record_dict[env_idx] = False | ||||
|                 warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n' | ||||
|                               'Nothing will be recorded') | ||||
|             self._episode_counter += 1 | ||||
|         else: | ||||
|             pass | ||||
|         return self._do_record_dict[env_idx] | ||||
|  | ||||
|     def _on_step(self) -> bool: | ||||
|         do_record = self.freq == -1 or self._episode_counter % self.freq == 0 | ||||
|         for env_idx, info in enumerate(self.locals.get('infos', [])): | ||||
|             if do_record: | ||||
|             if self._do_record_dict[env_idx]: | ||||
|                 self._read_info(env_idx, info) | ||||
|         dones = list(enumerate(self.locals.get('dones', []))) | ||||
|         dones.extend(list(enumerate(self.locals.get('done', [])))) | ||||
|         for env_idx, done in dones: | ||||
|             if do_record: | ||||
|             if self._do_record_dict[env_idx]: | ||||
|                 self._read_done(env_idx, done) | ||||
|             if done: | ||||
|                 self._episode_counter += 1 | ||||
|  | ||||
|         return True | ||||
|  | ||||
|     def _on_training_end(self) -> None: | ||||
|   | ||||
| @@ -60,6 +60,7 @@ def encapsule_env_factory(env_fctry, env_kwrgs): | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     render = False | ||||
|     # Define Global Env Parameters | ||||
|     # Define properties object parameters | ||||
|     factory_kwargs = dict( | ||||
| @@ -140,17 +141,17 @@ if __name__ == '__main__': | ||||
|                     combined_env_kwargs[key] = val | ||||
|                 else: | ||||
|                     assert combined_env_kwargs[key] == val | ||||
|             del combined_env_kwargs['key'] | ||||
|             combined_env_kwargs.update(n_agents=len(comparable_runs)) | ||||
|  | ||||
|             with(type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs)) as combEnv: | ||||
|             with type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs) as combEnv: | ||||
|                 # EnvMonitor Init | ||||
|                 comb = f'comb_{model_name}_{seed}' | ||||
|                 comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick' | ||||
|                 comb_recorder_path = combinations_path / comb / f'{comb}_recorder.pick' | ||||
|                 comb_recorder_path = combinations_path / comb / f'{comb}_recorder.json' | ||||
|                 comb_monitor_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|                 monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path) | ||||
|                 # monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_monitor_path) | ||||
|                 monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_recorder_path, freq=1) | ||||
|  | ||||
|                 # Evaluation starts here ##################################################### | ||||
|                 # Load all models | ||||
| @@ -164,8 +165,9 @@ if __name__ == '__main__': | ||||
|                     *(agent.named_action_space for agent in loaded_models) | ||||
|                 ) | ||||
|  | ||||
|                 for episode in range(50): | ||||
|                     obs, _ = monitoredCombEnv.reset(), monitoredCombEnv.render() | ||||
|                 for episode in range(1): | ||||
|                     obs = monitoredCombEnv.reset() | ||||
|                     if render: monitoredCombEnv.render() | ||||
|                     rew, done_bool = 0, False | ||||
|                     while not done_bool: | ||||
|                         actions = [] | ||||
| @@ -176,12 +178,12 @@ if __name__ == '__main__': | ||||
|                         obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions) | ||||
|  | ||||
|                         rew += step_r | ||||
|                         monitoredCombEnv.render() | ||||
|                         if render: monitoredCombEnv.render() | ||||
|                         if done_bool: | ||||
|                             break | ||||
|                     print(f'Factory run {episode} done, reward is:\n    {rew}') | ||||
|                 # Eval monitor outputs are automatically stored by the monitor object | ||||
|                 # TODO: Plotting | ||||
|                 monitoredCombEnv.save_records(comb_monitor_path) | ||||
|                 monitoredCombEnv.save_records() | ||||
|                 monitoredCombEnv.save_run() | ||||
|             pass | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium