Smaller Bug Fixes and improvements

This commit is contained in:
Steffen Illium 2021-08-24 08:55:23 +02:00
parent c3d4925653
commit bd0a8090ab
2 changed files with 9 additions and 7 deletions

View File

@ -30,7 +30,7 @@ class DropOffLocation(Entity):
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(*args, **kwargs)
self.storage = deque(maxlen=storage_size_until_full)
self.storage = deque(maxlen=storage_size_until_full or None)
def place_item(self, item):
if self.is_full:
@ -102,10 +102,10 @@ class DoubleTaskFactory(SimpleFactory):
# Hard reset the Inventory Stat in OBS cube
self._slices[agent_slice_idx].slice[:] = 0
if len(agent.inventory) > 0:
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - max_x, 0), max(agent.y - max_x, 0))
for item in agent.inventory:
x_diff, y_diff = divmod(item, max_x)
max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0]
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - self.pomdp_r, 0), max(agent.y - self.pomdp_r, 0))
for item_idx, item in enumerate(agent.inventory):
x_diff, y_diff = divmod(item_idx, max_x)
self._slices[agent_slice_idx].slice[int(x+x_diff), int(y+y_diff)] = item
self._obs_cube[agent_slice_idx] = self._slices[agent_slice_idx].slice
@ -129,7 +129,8 @@ class DoubleTaskFactory(SimpleFactory):
return c.NOT_VALID
elif item != NO_ITEM:
if len(agent.inventory) < self.item_properties.max_agent_storage_size:
max_sto_size = self.item_properties.max_agent_storage_size or np.prod(self.observation_space.shape[1:])
if len(agent.inventory) < max_sto_size:
agent.inventory.append(item_slice[agent.pos])
item_slice[agent.pos] = NO_ITEM
else:

View File

@ -82,7 +82,8 @@ class SimpleFactory(BaseFactory):
def _flush_state(self):
super(SimpleFactory, self)._flush_state()
self._obs_cube[self._slices.get_idx(c.DIRT)] = self._slices.by_enum(c.DIRT).slice
dirt_slice_idx = self._slices.get_idx(c.DIRT)
self._obs_cube[dirt_slice_idx] = self._slices[dirt_slice_idx].slice
def render_additional_assets(self, mode='human'):
additional_assets = super(SimpleFactory, self).render_additional_assets()