recorder fixed
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import EntityCollection
|
||||
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||
|
||||
|
||||
class DirtRegister(EntityCollection):
|
||||
class DirtPiles(EntityCollection):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
_accepted_objects = DirtPile
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
@@ -18,7 +18,7 @@ class DirtRegister(EntityCollection):
|
||||
return self._dirt_properties
|
||||
|
||||
def __init__(self, dirt_properties, *args):
|
||||
super(DirtRegister, self).__init__(*args)
|
||||
super(DirtPiles, self).__init__(*args)
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
@@ -28,7 +28,7 @@ class DirtRegister(EntityCollection):
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
@@ -38,5 +38,5 @@ class DirtRegister(EntityCollection):
|
||||
return c.VALID
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtRegister, self).__repr__()
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
return f'{s[:-1]}, {self.amount})'
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from environments.factory.base.objects import Entity
|
||||
|
||||
|
||||
class Dirt(Entity):
|
||||
class DirtPile(Entity):
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
@@ -13,14 +13,14 @@ class Dirt(Entity):
|
||||
return self._amount
|
||||
|
||||
def __init__(self, *args, amount=None, **kwargs):
|
||||
super(Dirt, self).__init__(*args, **kwargs)
|
||||
super(DirtPile, self).__init__(*args, **kwargs)
|
||||
self._amount = amount
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
|
||||
@@ -4,7 +4,7 @@ from environments.helpers import Constants as BaseConstants, EnvActions as BaseA
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
DIRT = 'Dirt'
|
||||
DIRT = 'DirtPile'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
|
||||
@@ -5,8 +5,8 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtPiles
|
||||
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
@@ -43,7 +43,7 @@ class DirtFactory(BaseFactory):
|
||||
@property
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
super_entities = super().entities_hook
|
||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
||||
dirt_register = DirtPiles(self.dirt_prop, self._level_shape)
|
||||
super_entities.update(({c.DIRT: dirt_register}))
|
||||
return super_entities
|
||||
|
||||
@@ -57,7 +57,7 @@ class DirtFactory(BaseFactory):
|
||||
self.dirt_prop = dirt_prop
|
||||
self.rewards_dirt = rewards_dirt
|
||||
self._dirt_rng = np.random.default_rng(env_seed)
|
||||
self._dirt: DirtRegister
|
||||
self._dirt: DirtPiles
|
||||
kwargs.update(env_seed=env_seed)
|
||||
# TODO: Reset ---> document this
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -96,7 +96,7 @@ class DirtFactory(BaseFactory):
|
||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||
dirt_rng = self._dirt_rng
|
||||
free_for_dirt = [x for x in self[c.FLOOR]
|
||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt))
|
||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile))
|
||||
]
|
||||
self._dirt_rng.shuffle(free_for_dirt)
|
||||
if initial_spawn:
|
||||
@@ -123,7 +123,7 @@ class DirtFactory(BaseFactory):
|
||||
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||
if self._next_dirt_spawn < 0:
|
||||
pass # No Dirt Spawn
|
||||
pass # No DirtPile Spawn
|
||||
elif not self._next_dirt_spawn:
|
||||
self.trigger_dirt_spawn()
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
||||
|
||||
Reference in New Issue
Block a user