mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-23 03:51:35 +02:00
Error Resolvement
This commit is contained in:
@ -9,7 +9,7 @@ from environments.helpers import Constants as c
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
||||
MovingEntityObjectRegister
|
||||
MovingEntityObjectRegister, Register
|
||||
|
||||
from environments.factory.renderer import RenderEntity
|
||||
|
||||
@ -66,22 +66,19 @@ class Inventory(UserList):
|
||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
||||
super(Inventory, self).__init__()
|
||||
self.agent = agent
|
||||
self.capacity = capacity
|
||||
self.pomdp_r = pomdp_r
|
||||
self._level_shape = level_shape
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
if self.pomdp_r:
|
||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
||||
else:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
self.capacity = min(capacity, self._array.size)
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0]
|
||||
if self.pomdp_r:
|
||||
x, y = max(self.agent.x - self.pomdp_r, 0), max(self.agent.y - self.pomdp_r, 0)
|
||||
else:
|
||||
x, y = (0, 0)
|
||||
|
||||
for item_idx, item in enumerate(self):
|
||||
x_diff, y_diff = divmod(item_idx, max_x)
|
||||
self._array[0, int(x + x_diff), int(y + y_diff)] = item.encoding
|
||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
@ -105,8 +102,9 @@ class Inventories(ObjectRegister):
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
hide_from_obs_builder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, pomdp_r=0, **kwargs):
|
||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
@ -213,13 +211,18 @@ class ItemFactory(BaseFactory):
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_items]
|
||||
item_register.spawn_items(empty_tiles)
|
||||
|
||||
inventories = Inventories(self._level_shape)
|
||||
inventories = Inventories(self._level_shape if not self.pomdp_r else ((self.pomdp_diameter,) * 2))
|
||||
inventories.spawn_inventories(self[c.AGENT], self.pomdp_r,
|
||||
self.item_properties.max_agent_inventory_capacity)
|
||||
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
|
||||
def additional_obs_build(self) -> List[np.ndarray]:
|
||||
super_additional_obs_build = super().additional_obs_build()
|
||||
super_additional_obs_build.append(self[c.INVENTORY].as_array())
|
||||
return super_additional_obs_build
|
||||
|
||||
def do_item_action(self, agent: Agent):
|
||||
inventory = self[c.INVENTORY].by_entity(agent)
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
@ -285,7 +288,7 @@ class ItemFactory(BaseFactory):
|
||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_item_dropoff': 1})
|
||||
|
||||
reward += 1
|
||||
reward += 0.5
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
||||
reward += 0.1
|
||||
|
Reference in New Issue
Block a user