mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Smaller Bug Fixes and improvements
This commit is contained in:
parent
c3d4925653
commit
bd0a8090ab
@ -30,7 +30,7 @@ class DropOffLocation(Entity):
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.storage = deque(maxlen=storage_size_until_full)
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item):
|
||||
if self.is_full:
|
||||
@ -102,10 +102,10 @@ class DoubleTaskFactory(SimpleFactory):
|
||||
# Hard reset the Inventory Stat in OBS cube
|
||||
self._slices[agent_slice_idx].slice[:] = 0
|
||||
if len(agent.inventory) > 0:
|
||||
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
|
||||
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - max_x, 0), max(agent.y - max_x, 0))
|
||||
for item in agent.inventory:
|
||||
x_diff, y_diff = divmod(item, max_x)
|
||||
max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0]
|
||||
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - self.pomdp_r, 0), max(agent.y - self.pomdp_r, 0))
|
||||
for item_idx, item in enumerate(agent.inventory):
|
||||
x_diff, y_diff = divmod(item_idx, max_x)
|
||||
self._slices[agent_slice_idx].slice[int(x+x_diff), int(y+y_diff)] = item
|
||||
self._obs_cube[agent_slice_idx] = self._slices[agent_slice_idx].slice
|
||||
|
||||
@ -129,7 +129,8 @@ class DoubleTaskFactory(SimpleFactory):
|
||||
return c.NOT_VALID
|
||||
|
||||
elif item != NO_ITEM:
|
||||
if len(agent.inventory) < self.item_properties.max_agent_storage_size:
|
||||
max_sto_size = self.item_properties.max_agent_storage_size or np.prod(self.observation_space.shape[1:])
|
||||
if len(agent.inventory) < max_sto_size:
|
||||
agent.inventory.append(item_slice[agent.pos])
|
||||
item_slice[agent.pos] = NO_ITEM
|
||||
else:
|
||||
|
@ -82,7 +82,8 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
def _flush_state(self):
|
||||
super(SimpleFactory, self)._flush_state()
|
||||
self._obs_cube[self._slices.get_idx(c.DIRT)] = self._slices.by_enum(c.DIRT).slice
|
||||
dirt_slice_idx = self._slices.get_idx(c.DIRT)
|
||||
self._obs_cube[dirt_slice_idx] = self._slices[dirt_slice_idx].slice
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
additional_assets = super(SimpleFactory, self).render_additional_assets()
|
||||
|
Loading…
x
Reference in New Issue
Block a user