recorder fixed
This commit is contained in:
parent
6a24e7b518
commit
4f3924d3ab
@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
|
|||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
class Constants(BaseConstants):
|
||||||
DIRT = 'Dirt'
|
DIRT = 'DirtPile'
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
class Actions(BaseActions):
|
||||||
|
@ -2,24 +2,19 @@ from environments.factory.additional.btry.btry_objects import Battery, ChargePod
|
|||||||
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
||||||
|
|
||||||
|
|
||||||
class BatteriesRegister(EnvObjectCollection):
|
class Batteries(EnvObjectCollection):
|
||||||
|
|
||||||
_accepted_objects = Battery
|
_accepted_objects = Battery
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
super(Batteries, self).__init__(*args, individual_slices=True,
|
||||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
self.is_observable = True
|
self.is_observable = True
|
||||||
|
|
||||||
def spawn_batteries(self, agents, initial_charge_level):
|
def spawn_batteries(self, agents, initial_charge_level):
|
||||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||||
self.add_additional_items(batteries)
|
self.add_additional_items(batteries)
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# as dict with additional nesting
|
|
||||||
# return dict(items=super(Inventories, cls).summarize_states())
|
|
||||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
|
||||||
|
|
||||||
# Todo Move this to Mixin!
|
# Todo Move this to Mixin!
|
||||||
def by_entity(self, entity):
|
def by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
@ -40,11 +35,7 @@ class BatteriesRegister(EnvObjectCollection):
|
|||||||
class ChargePods(EntityCollection):
|
class ChargePods(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
_accepted_objects = ChargePod
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
super(ChargePods, self).__repr__()
|
super(ChargePods, self).__repr__()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# as dict with additional nesting
|
|
||||||
# return dict(items=super(Inventories, cls).summarize_states())
|
|
||||||
return super(ChargePods, self).summarize_states(n_steps=n_steps)
|
|
||||||
|
@ -35,7 +35,7 @@ class Battery(BoundingMixin, EnvObject):
|
|||||||
|
|
||||||
def summarize_state(self, **_):
|
def summarize_state(self, **_):
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
attr_dict.update(dict(name=self.name))
|
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
@ -58,10 +58,3 @@ class ChargePod(Entity):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
valid = battery.do_charge_action(self.charge_rate)
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
return valid
|
return valid
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
summary = super().summarize_state(n_steps=n_steps)
|
|
||||||
return summary
|
|
||||||
else:
|
|
||||||
{}
|
|
||||||
|
@ -2,7 +2,7 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods
|
from environments.factory.additional.btry.btry_collections import Batteries, ChargePods
|
||||||
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action
|
from environments.factory.base.objects import Agent, Action
|
||||||
@ -45,8 +45,8 @@ class BatteryFactory(BaseFactory):
|
|||||||
multi_charge=self.btry_prop.multi_charge)
|
multi_charge=self.btry_prop.multi_charge)
|
||||||
)
|
)
|
||||||
|
|
||||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||||
)
|
)
|
||||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
@ -25,9 +25,6 @@ class Destinations(EntityCollection):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return super(Destinations, self).__repr__()
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ReachedDestinations(Destinations):
|
class ReachedDestinations(Destinations):
|
||||||
_accepted_objects = Destination
|
_accepted_objects = Destination
|
||||||
@ -37,8 +34,5 @@ class ReachedDestinations(Destinations):
|
|||||||
self.can_be_shadowed = False
|
self.can_be_shadowed = False
|
||||||
self.is_blocking_light = False
|
self.is_blocking_light = False
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return super(ReachedDestinations, self).__repr__()
|
return super(ReachedDestinations, self).__repr__()
|
||||||
|
@ -38,7 +38,8 @@ class Destination(Entity):
|
|||||||
def agent_is_dwelling(self, agent: Agent):
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
return self._per_agent_times[agent.name] < self.dwell_time
|
return self._per_agent_times[agent.name] < self.dwell_time
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
state_summary = super().summarize_state(n_steps=n_steps)
|
state_summary = super().summarize_state()
|
||||||
state_summary.update(per_agent_times=self._per_agent_times)
|
state_summary.update(per_agent_times=[
|
||||||
|
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time)
|
||||||
return state_summary
|
return state_summary
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||||
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
from environments.factory.base.objects import Floor
|
from environments.factory.base.objects import Floor
|
||||||
from environments.factory.base.registers import EntityCollection
|
from environments.factory.base.registers import EntityCollection
|
||||||
from environments.factory.additional.dirt.dirt_util import Constants as c
|
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(EntityCollection):
|
class DirtPiles(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
_accepted_objects = DirtPile
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self):
|
def amount(self):
|
||||||
@ -18,7 +18,7 @@ class DirtRegister(EntityCollection):
|
|||||||
return self._dirt_properties
|
return self._dirt_properties
|
||||||
|
|
||||||
def __init__(self, dirt_properties, *args):
|
def __init__(self, dirt_properties, *args):
|
||||||
super(DirtRegister, self).__init__(*args)
|
super(DirtPiles, self).__init__(*args)
|
||||||
self._dirt_properties: DirtProperties = dirt_properties
|
self._dirt_properties: DirtProperties = dirt_properties
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||||
@ -28,7 +28,7 @@ class DirtRegister(EntityCollection):
|
|||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
dirt = self.by_pos(tile.pos)
|
dirt = self.by_pos(tile.pos)
|
||||||
if dirt is None:
|
if dirt is None:
|
||||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
self.add_item(dirt)
|
self.add_item(dirt)
|
||||||
else:
|
else:
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
@ -38,5 +38,5 @@ class DirtRegister(EntityCollection):
|
|||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
s = super(DirtRegister, self).__repr__()
|
s = super(DirtPiles, self).__repr__()
|
||||||
return f'{s[:-1]}, {self.amount})'
|
return f'{s[:-1]}, {self.amount})'
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from environments.factory.base.objects import Entity
|
from environments.factory.base.objects import Entity
|
||||||
|
|
||||||
|
|
||||||
class Dirt(Entity):
|
class DirtPile(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self):
|
def amount(self):
|
||||||
@ -13,14 +13,14 @@ class Dirt(Entity):
|
|||||||
return self._amount
|
return self._amount
|
||||||
|
|
||||||
def __init__(self, *args, amount=None, **kwargs):
|
def __init__(self, *args, amount=None, **kwargs):
|
||||||
super(Dirt, self).__init__(*args, **kwargs)
|
super(DirtPile, self).__init__(*args, **kwargs)
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
|
|
||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
self._collection.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(amount=float(self.amount))
|
state_dict.update(amount=float(self.amount))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -4,7 +4,7 @@ from environments.helpers import Constants as BaseConstants, EnvActions as BaseA
|
|||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
class Constants(BaseConstants):
|
||||||
DIRT = 'Dirt'
|
DIRT = 'DirtPile'
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
class Actions(BaseActions):
|
||||||
|
@ -5,8 +5,8 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
from environments.factory.additional.dirt.dirt_collections import DirtPiles
|
||||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||||
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
@ -43,7 +43,7 @@ class DirtFactory(BaseFactory):
|
|||||||
@property
|
@property
|
||||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
super_entities = super().entities_hook
|
super_entities = super().entities_hook
|
||||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
dirt_register = DirtPiles(self.dirt_prop, self._level_shape)
|
||||||
super_entities.update(({c.DIRT: dirt_register}))
|
super_entities.update(({c.DIRT: dirt_register}))
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ class DirtFactory(BaseFactory):
|
|||||||
self.dirt_prop = dirt_prop
|
self.dirt_prop = dirt_prop
|
||||||
self.rewards_dirt = rewards_dirt
|
self.rewards_dirt = rewards_dirt
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
self._dirt: DirtRegister
|
self._dirt: DirtPiles
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
# TODO: Reset ---> document this
|
# TODO: Reset ---> document this
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -96,7 +96,7 @@ class DirtFactory(BaseFactory):
|
|||||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||||
dirt_rng = self._dirt_rng
|
dirt_rng = self._dirt_rng
|
||||||
free_for_dirt = [x for x in self[c.FLOOR]
|
free_for_dirt = [x for x in self[c.FLOOR]
|
||||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt))
|
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile))
|
||||||
]
|
]
|
||||||
self._dirt_rng.shuffle(free_for_dirt)
|
self._dirt_rng.shuffle(free_for_dirt)
|
||||||
if initial_spawn:
|
if initial_spawn:
|
||||||
@ -123,7 +123,7 @@ class DirtFactory(BaseFactory):
|
|||||||
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||||
if self._next_dirt_spawn < 0:
|
if self._next_dirt_spawn < 0:
|
||||||
pass # No Dirt Spawn
|
pass # No DirtPile Spawn
|
||||||
elif not self._next_dirt_spawn:
|
elif not self._next_dirt_spawn:
|
||||||
self.trigger_dirt_spawn()
|
self.trigger_dirt_spawn()
|
||||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
||||||
|
@ -3,7 +3,7 @@ from typing import List, Union, Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations
|
from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations
|
||||||
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
|
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action
|
from environments.factory.base.objects import Agent, Action
|
||||||
@ -49,7 +49,7 @@ class ItemFactory(BaseFactory):
|
|||||||
entity_kwargs=dict(
|
entity_kwargs=dict(
|
||||||
storage_size_until_full=self.item_prop.max_dropoff_storage_size)
|
storage_size_until_full=self.item_prop.max_dropoff_storage_size)
|
||||||
)
|
)
|
||||||
item_register = ItemRegister(self._level_shape)
|
item_register = Items(self._level_shape)
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||||
item_register.spawn_items(empty_tiles)
|
item_register.spawn_items(empty_tiles)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from environments.factory.base.registers import EntityCollection, BoundEnvObjCol
|
|||||||
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
||||||
|
|
||||||
|
|
||||||
class ItemRegister(EntityCollection):
|
class Items(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
@ -37,9 +37,9 @@ class Inventory(BoundEnvObjCollection):
|
|||||||
return super(Inventory, self).as_array()
|
return super(Inventory, self).as_array()
|
||||||
|
|
||||||
def summarize_states(self, **kwargs):
|
def summarize_states(self, **kwargs):
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
|
||||||
attr_dict.update(dict(name=self.name))
|
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
def pop(self):
|
def pop(self):
|
||||||
@ -79,9 +79,11 @@ class Inventories(ObjectCollection):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def summarize_states(self, **kwargs):
|
def summarize_states(self, **kwargs):
|
||||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityCollection):
|
class DropOffLocations(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = DropOffLocation
|
_accepted_objects = DropOffLocation
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class Item(Entity):
|
|||||||
def set_tile_to(self, no_pos_tile):
|
def set_tile_to(self, no_pos_tile):
|
||||||
self._tile = no_pos_tile
|
self._tile = no_pos_tile
|
||||||
|
|
||||||
def summarize_state(self, **_) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
super_summarization = super(Item, self).summarize_state()
|
super_summarization = super(Item, self).summarize_state()
|
||||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||||
return super_summarization
|
return super_summarization
|
||||||
@ -55,7 +55,3 @@ class DropOffLocation(Entity):
|
|||||||
@property
|
@property
|
||||||
def is_full(self):
|
def is_full(self):
|
||||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
return super().summarize_state(n_steps=n_steps)
|
|
||||||
|
@ -71,6 +71,12 @@ class BaseFactory(gym.Env):
|
|||||||
d['class_name'] = self.__class__.__name__
|
d['class_name'] = self.__class__.__name__
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@property
|
||||||
|
def summarize_header(self):
|
||||||
|
summary_dict = self._summarize_state(stateless_entities=True)
|
||||||
|
summary_dict.update(actions=self._actions.summarize())
|
||||||
|
return summary_dict
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self if self.obs_prop.frames_to_stack == 0 else \
|
return self if self.obs_prop.frames_to_stack == 0 else \
|
||||||
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
||||||
@ -665,12 +671,12 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _summarize_state(self):
|
def _summarize_state(self, stateless_entities=False):
|
||||||
summary = {f'{REC_TAC}step': self._steps}
|
summary = {f'{REC_TAC}step': self._steps}
|
||||||
|
|
||||||
for entity_group in self._entities:
|
for entity_group in self._entities:
|
||||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
|
if entity_group.is_stateless == stateless_entities:
|
||||||
|
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
|
@ -86,7 +86,7 @@ class EnvObject(Object):
|
|||||||
|
|
||||||
# TODO: Missing Documentation
|
# TODO: Missing Documentation
|
||||||
class Entity(EnvObject):
|
class Entity(EnvObject):
|
||||||
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
"""Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
@ -113,7 +113,7 @@ class Entity(EnvObject):
|
|||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self, **_) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
@ -338,8 +338,8 @@ class Door(Entity):
|
|||||||
if not closed_on_init:
|
if not closed_on_init:
|
||||||
self._open()
|
self._open()
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
@ -402,7 +402,7 @@ class Agent(MoveableEntity):
|
|||||||
# if attr.startswith('temp'):
|
# if attr.startswith('temp'):
|
||||||
self.step_result = None
|
self.step_result = None
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -19,6 +19,11 @@ from environments.helpers import Constants as c
|
|||||||
|
|
||||||
class ObjectCollection:
|
class ObjectCollection:
|
||||||
_accepted_objects = Object
|
_accepted_objects = Object
|
||||||
|
_stateless_entities = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_stateless(self):
|
||||||
|
return self._stateless_entities
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -116,8 +121,8 @@ class EnvObjectCollection(ObjectCollection):
|
|||||||
self._lazy_eval_transforms = []
|
self._lazy_eval_transforms = []
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self):
|
||||||
return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]
|
return [entity.summarize_state() for entity in self.values()]
|
||||||
|
|
||||||
def notify_change_to_free(self, env_object: EnvObject):
|
def notify_change_to_free(self, env_object: EnvObject):
|
||||||
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||||
@ -290,9 +295,6 @@ class GlobalPositions(EnvObjectCollection):
|
|||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.add_additional_items(global_positions)
|
self.add_additional_items(global_positions)
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||||
@ -376,6 +378,7 @@ class Entities(ObjectCollection):
|
|||||||
|
|
||||||
class Walls(EntityCollection):
|
class Walls(EntityCollection):
|
||||||
_accepted_objects = Wall
|
_accepted_objects = Wall
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if not np.any(self._array):
|
if not np.any(self._array):
|
||||||
@ -406,15 +409,10 @@ class Walls(EntityCollection):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
return super(Walls, self).summarize_states(n_steps=n_steps)
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class Floors(Walls):
|
class Floors(Walls):
|
||||||
_accepted_objects = Floor
|
_accepted_objects = Floor
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||||
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||||
@ -436,10 +434,6 @@ class Floors(Walls):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# Do not summarize
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectCollection):
|
class Agents(MovingEntityObjectCollection):
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
@ -521,6 +515,9 @@ class Actions(ObjectCollection):
|
|||||||
def is_moving_action(self, action: Union[int]):
|
def is_moving_action(self, action: Union[int]):
|
||||||
return action in self.movement_actions.values()
|
return action in self.movement_actions.values()
|
||||||
|
|
||||||
|
def summarize(self):
|
||||||
|
return [dict(name=action.identifier) for action in self]
|
||||||
|
|
||||||
|
|
||||||
class Zones(ObjectCollection):
|
class Zones(ObjectCollection):
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import numpy as np
|
|||||||
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.factory_dirt import DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
from environments.factory.additional.dirt.dirt_collections import DirtPiles
|
||||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||||
from environments.factory.base.objects import Floor
|
from environments.factory.base.objects import Floor
|
||||||
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -21,6 +22,7 @@ class EnvRecorder(BaseCallback):
|
|||||||
self._recorder_dict = defaultdict(list)
|
self._recorder_dict = defaultdict(list)
|
||||||
self._recorder_out_list = list()
|
self._recorder_out_list = list()
|
||||||
self._episode_counter = 1
|
self._episode_counter = 1
|
||||||
|
self._do_record_dict = defaultdict(lambda: False)
|
||||||
if isinstance(entities, str):
|
if isinstance(entities, str):
|
||||||
if entities.lower() == 'all':
|
if entities.lower() == 'all':
|
||||||
self._entities = None
|
self._entities = None
|
||||||
@ -58,7 +60,11 @@ class EnvRecorder(BaseCallback):
|
|||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
step_result = self.unwrapped.step(actions)
|
step_result = self.unwrapped.step(actions)
|
||||||
self._on_step()
|
if self.do_record_episode(0):
|
||||||
|
info = step_result[-1]
|
||||||
|
self._read_info(0, info)
|
||||||
|
if self._do_record_dict[0]:
|
||||||
|
self._read_done(0, step_result[-2])
|
||||||
return step_result
|
return step_result
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
@ -71,7 +77,8 @@ class EnvRecorder(BaseCallback):
|
|||||||
# cls.out_file.unlink(missing_ok=True)
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
out_dict = {'n_episodes': self._episode_counter,
|
out_dict = {'n_episodes': self._episode_counter,
|
||||||
'header': self.unwrapped.params,
|
'env_params': self.unwrapped.params,
|
||||||
|
'header': self.unwrapped.summarize_header,
|
||||||
'episodes': self._recorder_out_list
|
'episodes': self._recorder_out_list
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
@ -99,18 +106,29 @@ class EnvRecorder(BaseCallback):
|
|||||||
if save_trajectory_map:
|
if save_trajectory_map:
|
||||||
raise NotImplementedError('This has not yet been implemented.')
|
raise NotImplementedError('This has not yet been implemented.')
|
||||||
|
|
||||||
|
def do_record_episode(self, env_idx):
|
||||||
|
if not self._recorder_dict[env_idx]:
|
||||||
|
if self.freq:
|
||||||
|
self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
|
||||||
|
else:
|
||||||
|
self._do_record_dict[env_idx] = False
|
||||||
|
warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
|
||||||
|
'Nothing will be recorded')
|
||||||
|
self._episode_counter += 1
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return self._do_record_dict[env_idx]
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
do_record = self.freq == -1 or self._episode_counter % self.freq == 0
|
|
||||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||||
if do_record:
|
if self._do_record_dict[env_idx]:
|
||||||
self._read_info(env_idx, info)
|
self._read_info(env_idx, info)
|
||||||
dones = list(enumerate(self.locals.get('dones', [])))
|
dones = list(enumerate(self.locals.get('dones', [])))
|
||||||
dones.extend(list(enumerate(self.locals.get('done', []))))
|
dones.extend(list(enumerate(self.locals.get('done', []))))
|
||||||
for env_idx, done in dones:
|
for env_idx, done in dones:
|
||||||
if do_record:
|
if self._do_record_dict[env_idx]:
|
||||||
self._read_done(env_idx, done)
|
self._read_done(env_idx, done)
|
||||||
if done:
|
|
||||||
self._episode_counter += 1
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def _on_training_end(self) -> None:
|
||||||
|
@ -60,6 +60,7 @@ def encapsule_env_factory(env_fctry, env_kwrgs):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
render = False
|
||||||
# Define Global Env Parameters
|
# Define Global Env Parameters
|
||||||
# Define properties object parameters
|
# Define properties object parameters
|
||||||
factory_kwargs = dict(
|
factory_kwargs = dict(
|
||||||
@ -140,17 +141,17 @@ if __name__ == '__main__':
|
|||||||
combined_env_kwargs[key] = val
|
combined_env_kwargs[key] = val
|
||||||
else:
|
else:
|
||||||
assert combined_env_kwargs[key] == val
|
assert combined_env_kwargs[key] == val
|
||||||
|
del combined_env_kwargs['key']
|
||||||
combined_env_kwargs.update(n_agents=len(comparable_runs))
|
combined_env_kwargs.update(n_agents=len(comparable_runs))
|
||||||
|
with type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs) as combEnv:
|
||||||
with(type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs)) as combEnv:
|
|
||||||
# EnvMonitor Init
|
# EnvMonitor Init
|
||||||
comb = f'comb_{model_name}_{seed}'
|
comb = f'comb_{model_name}_{seed}'
|
||||||
comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick'
|
comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick'
|
||||||
comb_recorder_path = combinations_path / comb / f'{comb}_recorder.pick'
|
comb_recorder_path = combinations_path / comb / f'{comb}_recorder.json'
|
||||||
comb_monitor_path.parent.mkdir(parents=True, exist_ok=True)
|
comb_monitor_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path)
|
monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path)
|
||||||
# monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_monitor_path)
|
monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_recorder_path, freq=1)
|
||||||
|
|
||||||
# Evaluation starts here #####################################################
|
# Evaluation starts here #####################################################
|
||||||
# Load all models
|
# Load all models
|
||||||
@ -164,8 +165,9 @@ if __name__ == '__main__':
|
|||||||
*(agent.named_action_space for agent in loaded_models)
|
*(agent.named_action_space for agent in loaded_models)
|
||||||
)
|
)
|
||||||
|
|
||||||
for episode in range(50):
|
for episode in range(1):
|
||||||
obs, _ = monitoredCombEnv.reset(), monitoredCombEnv.render()
|
obs = monitoredCombEnv.reset()
|
||||||
|
if render: monitoredCombEnv.render()
|
||||||
rew, done_bool = 0, False
|
rew, done_bool = 0, False
|
||||||
while not done_bool:
|
while not done_bool:
|
||||||
actions = []
|
actions = []
|
||||||
@ -176,12 +178,12 @@ if __name__ == '__main__':
|
|||||||
obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions)
|
obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions)
|
||||||
|
|
||||||
rew += step_r
|
rew += step_r
|
||||||
monitoredCombEnv.render()
|
if render: monitoredCombEnv.render()
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
# Eval monitor outputs are automatically stored by the monitor object
|
# Eval monitor outputs are automatically stored by the monitor object
|
||||||
# TODO: Plotting
|
# TODO: Plotting
|
||||||
monitoredCombEnv.save_records(comb_monitor_path)
|
monitoredCombEnv.save_records()
|
||||||
monitoredCombEnv.save_run()
|
monitoredCombEnv.save_run()
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user