recorder fixed
This commit is contained in:
@@ -2,24 +2,19 @@ from environments.factory.additional.btry.btry_objects import Battery, ChargePod
|
||||
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
||||
|
||||
|
||||
class BatteriesRegister(EnvObjectCollection):
|
||||
class Batteries(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = Battery
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
super(Batteries, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.add_additional_items(batteries)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
# Todo Move this to Mixin!
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
@@ -40,11 +35,7 @@ class BatteriesRegister(EnvObjectCollection):
|
||||
class ChargePods(EntityCollection):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
_stateless_entities = True
|
||||
|
||||
def __repr__(self):
|
||||
super(ChargePods, self).__repr__()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(ChargePods, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
@@ -35,7 +35,7 @@ class Battery(BoundingMixin, EnvObject):
|
||||
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name))
|
||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||
return attr_dict
|
||||
|
||||
|
||||
@@ -58,10 +58,3 @@ class ChargePod(Entity):
|
||||
return c.NOT_VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
summary = super().summarize_state(n_steps=n_steps)
|
||||
return summary
|
||||
else:
|
||||
{}
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods
|
||||
from environments.factory.additional.btry.btry_collections import Batteries, ChargePods
|
||||
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
@@ -45,8 +45,8 @@ class BatteryFactory(BaseFactory):
|
||||
multi_charge=self.btry_prop.multi_charge)
|
||||
)
|
||||
|
||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
@@ -25,9 +25,6 @@ class Destinations(EntityCollection):
|
||||
def __repr__(self):
|
||||
return super(Destinations, self).__repr__()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_accepted_objects = Destination
|
||||
@@ -37,8 +34,5 @@ class ReachedDestinations(Destinations):
|
||||
self.can_be_shadowed = False
|
||||
self.is_blocking_light = False
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
def __repr__(self):
|
||||
return super(ReachedDestinations, self).__repr__()
|
||||
|
||||
@@ -38,7 +38,8 @@ class Destination(Entity):
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
return self._per_agent_times[agent.name] < self.dwell_time
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
state_summary = super().summarize_state(n_steps=n_steps)
|
||||
state_summary.update(per_agent_times=self._per_agent_times)
|
||||
def summarize_state(self) -> dict:
|
||||
state_summary = super().summarize_state()
|
||||
state_summary.update(per_agent_times=[
|
||||
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time)
|
||||
return state_summary
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import EntityCollection
|
||||
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||
|
||||
|
||||
class DirtRegister(EntityCollection):
|
||||
class DirtPiles(EntityCollection):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
_accepted_objects = DirtPile
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
@@ -18,7 +18,7 @@ class DirtRegister(EntityCollection):
|
||||
return self._dirt_properties
|
||||
|
||||
def __init__(self, dirt_properties, *args):
|
||||
super(DirtRegister, self).__init__(*args)
|
||||
super(DirtPiles, self).__init__(*args)
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
@@ -28,7 +28,7 @@ class DirtRegister(EntityCollection):
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
@@ -38,5 +38,5 @@ class DirtRegister(EntityCollection):
|
||||
return c.VALID
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtRegister, self).__repr__()
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
return f'{s[:-1]}, {self.amount})'
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from environments.factory.base.objects import Entity
|
||||
|
||||
|
||||
class Dirt(Entity):
|
||||
class DirtPile(Entity):
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
@@ -13,14 +13,14 @@ class Dirt(Entity):
|
||||
return self._amount
|
||||
|
||||
def __init__(self, *args, amount=None, **kwargs):
|
||||
super(Dirt, self).__init__(*args, **kwargs)
|
||||
super(DirtPile, self).__init__(*args, **kwargs)
|
||||
self._amount = amount
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
|
||||
@@ -4,7 +4,7 @@ from environments.helpers import Constants as BaseConstants, EnvActions as BaseA
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
DIRT = 'Dirt'
|
||||
DIRT = 'DirtPile'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
|
||||
@@ -5,8 +5,8 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtPiles
|
||||
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
@@ -43,7 +43,7 @@ class DirtFactory(BaseFactory):
|
||||
@property
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
super_entities = super().entities_hook
|
||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
||||
dirt_register = DirtPiles(self.dirt_prop, self._level_shape)
|
||||
super_entities.update(({c.DIRT: dirt_register}))
|
||||
return super_entities
|
||||
|
||||
@@ -57,7 +57,7 @@ class DirtFactory(BaseFactory):
|
||||
self.dirt_prop = dirt_prop
|
||||
self.rewards_dirt = rewards_dirt
|
||||
self._dirt_rng = np.random.default_rng(env_seed)
|
||||
self._dirt: DirtRegister
|
||||
self._dirt: DirtPiles
|
||||
kwargs.update(env_seed=env_seed)
|
||||
# TODO: Reset ---> document this
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -96,7 +96,7 @@ class DirtFactory(BaseFactory):
|
||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||
dirt_rng = self._dirt_rng
|
||||
free_for_dirt = [x for x in self[c.FLOOR]
|
||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt))
|
||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile))
|
||||
]
|
||||
self._dirt_rng.shuffle(free_for_dirt)
|
||||
if initial_spawn:
|
||||
@@ -123,7 +123,7 @@ class DirtFactory(BaseFactory):
|
||||
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||
if self._next_dirt_spawn < 0:
|
||||
pass # No Dirt Spawn
|
||||
pass # No DirtPile Spawn
|
||||
elif not self._next_dirt_spawn:
|
||||
self.trigger_dirt_spawn()
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Union, Dict
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations
|
||||
from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations
|
||||
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
@@ -49,7 +49,7 @@ class ItemFactory(BaseFactory):
|
||||
entity_kwargs=dict(
|
||||
storage_size_until_full=self.item_prop.max_dropoff_storage_size)
|
||||
)
|
||||
item_register = ItemRegister(self._level_shape)
|
||||
item_register = Items(self._level_shape)
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||
item_register.spawn_items(empty_tiles)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from environments.factory.base.registers import EntityCollection, BoundEnvObjCol
|
||||
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
||||
|
||||
|
||||
class ItemRegister(EntityCollection):
|
||||
class Items(EntityCollection):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
@@ -37,9 +37,9 @@ class Inventory(BoundEnvObjCollection):
|
||||
return super(Inventory, self).as_array()
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||
attr_dict.update(dict(name=self.name))
|
||||
attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
|
||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self):
|
||||
@@ -79,9 +79,11 @@ class Inventories(ObjectCollection):
|
||||
return None
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||
|
||||
|
||||
class DropOffLocations(EntityCollection):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
_stateless_entities = True
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class Item(Entity):
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
self._tile = no_pos_tile
|
||||
|
||||
def summarize_state(self, **_) -> dict:
|
||||
def summarize_state(self) -> dict:
|
||||
super_summarization = super(Item, self).summarize_state()
|
||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||
return super_summarization
|
||||
@@ -55,7 +55,3 @@ class DropOffLocation(Entity):
|
||||
@property
|
||||
def is_full(self):
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
|
||||
Reference in New Issue
Block a user