From bd0a8090ab239a2d7622870ba82b78088dbd5fdb Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 24 Aug 2021 08:55:23 +0200 Subject: [PATCH] Smaller Bug Fixes and improvements --- environments/factory/double_task_factory.py | 13 +++++++------ environments/factory/simple_factory.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/environments/factory/double_task_factory.py b/environments/factory/double_task_factory.py index af8dc9c..76e537b 100644 --- a/environments/factory/double_task_factory.py +++ b/environments/factory/double_task_factory.py @@ -30,7 +30,7 @@ class DropOffLocation(Entity): def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): super(DropOffLocation, self).__init__(*args, **kwargs) - self.storage = deque(maxlen=storage_size_until_full) + self.storage = deque(maxlen=storage_size_until_full or None) def place_item(self, item): if self.is_full: @@ -102,10 +102,10 @@ class DoubleTaskFactory(SimpleFactory): # Hard reset the Inventory Stat in OBS cube self._slices[agent_slice_idx].slice[:] = 0 if len(agent.inventory) > 0: - max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0] - x, y = (0, 0) if not self.pomdp_r else (max(agent.x - max_x, 0), max(agent.y - max_x, 0)) - for item in agent.inventory: - x_diff, y_diff = divmod(item, max_x) + max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0] + x, y = (0, 0) if not self.pomdp_r else (max(agent.x - self.pomdp_r, 0), max(agent.y - self.pomdp_r, 0)) + for item_idx, item in enumerate(agent.inventory): + x_diff, y_diff = divmod(item_idx, max_x) self._slices[agent_slice_idx].slice[int(x+x_diff), int(y+y_diff)] = item self._obs_cube[agent_slice_idx] = self._slices[agent_slice_idx].slice @@ -129,7 +129,8 @@ class DoubleTaskFactory(SimpleFactory): return c.NOT_VALID elif item != NO_ITEM: - if len(agent.inventory) < self.item_properties.max_agent_storage_size: + max_sto_size = self.item_properties.max_agent_storage_size or np.prod(self.observation_space.shape[1:]) + if len(agent.inventory) < max_sto_size: agent.inventory.append(item_slice[agent.pos]) item_slice[agent.pos] = NO_ITEM else: diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 20b340c..78f614d 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -82,7 +82,8 @@ class SimpleFactory(BaseFactory): def _flush_state(self): super(SimpleFactory, self)._flush_state() - self._obs_cube[self._slices.get_idx(c.DIRT)] = self._slices.by_enum(c.DIRT).slice + dirt_slice_idx = self._slices.get_idx(c.DIRT) + self._obs_cube[dirt_slice_idx] = self._slices[dirt_slice_idx].slice def render_additional_assets(self, mode='human'): additional_assets = super(SimpleFactory, self).render_additional_assets()