mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Smaller Bug Fixes and improvements
This commit is contained in:
parent
c3d4925653
commit
bd0a8090ab
@ -30,7 +30,7 @@ class DropOffLocation(Entity):
|
|||||||
|
|
||||||
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
||||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||||
self.storage = deque(maxlen=storage_size_until_full)
|
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||||
|
|
||||||
def place_item(self, item):
|
def place_item(self, item):
|
||||||
if self.is_full:
|
if self.is_full:
|
||||||
@ -102,10 +102,10 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
# Hard reset the Inventory Stat in OBS cube
|
# Hard reset the Inventory Stat in OBS cube
|
||||||
self._slices[agent_slice_idx].slice[:] = 0
|
self._slices[agent_slice_idx].slice[:] = 0
|
||||||
if len(agent.inventory) > 0:
|
if len(agent.inventory) > 0:
|
||||||
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
|
max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0]
|
||||||
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - max_x, 0), max(agent.y - max_x, 0))
|
x, y = (0, 0) if not self.pomdp_r else (max(agent.x - self.pomdp_r, 0), max(agent.y - self.pomdp_r, 0))
|
||||||
for item in agent.inventory:
|
for item_idx, item in enumerate(agent.inventory):
|
||||||
x_diff, y_diff = divmod(item, max_x)
|
x_diff, y_diff = divmod(item_idx, max_x)
|
||||||
self._slices[agent_slice_idx].slice[int(x+x_diff), int(y+y_diff)] = item
|
self._slices[agent_slice_idx].slice[int(x+x_diff), int(y+y_diff)] = item
|
||||||
self._obs_cube[agent_slice_idx] = self._slices[agent_slice_idx].slice
|
self._obs_cube[agent_slice_idx] = self._slices[agent_slice_idx].slice
|
||||||
|
|
||||||
@ -129,7 +129,8 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
elif item != NO_ITEM:
|
elif item != NO_ITEM:
|
||||||
if len(agent.inventory) < self.item_properties.max_agent_storage_size:
|
max_sto_size = self.item_properties.max_agent_storage_size or np.prod(self.observation_space.shape[1:])
|
||||||
|
if len(agent.inventory) < max_sto_size:
|
||||||
agent.inventory.append(item_slice[agent.pos])
|
agent.inventory.append(item_slice[agent.pos])
|
||||||
item_slice[agent.pos] = NO_ITEM
|
item_slice[agent.pos] = NO_ITEM
|
||||||
else:
|
else:
|
||||||
|
@ -82,7 +82,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
def _flush_state(self):
|
def _flush_state(self):
|
||||||
super(SimpleFactory, self)._flush_state()
|
super(SimpleFactory, self)._flush_state()
|
||||||
self._obs_cube[self._slices.get_idx(c.DIRT)] = self._slices.by_enum(c.DIRT).slice
|
dirt_slice_idx = self._slices.get_idx(c.DIRT)
|
||||||
|
self._obs_cube[dirt_slice_idx] = self._slices[dirt_slice_idx].slice
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
def render_additional_assets(self, mode='human'):
|
||||||
additional_assets = super(SimpleFactory, self).render_additional_assets()
|
additional_assets = super(SimpleFactory, self).render_additional_assets()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user