Error Resolvement

This commit is contained in:
Steffen Illium
2021-09-07 17:41:15 +02:00
parent 444ffe3f37
commit 50c0d90c77
4 changed files with 55 additions and 40 deletions

View File

@ -9,7 +9,7 @@ from environments.helpers import Constants as c
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
MovingEntityObjectRegister
MovingEntityObjectRegister, Register
from environments.factory.renderer import RenderEntity
@ -66,22 +66,19 @@ class Inventory(UserList):
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
super(Inventory, self).__init__()
self.agent = agent
self.capacity = capacity
self.pomdp_r = pomdp_r
self._level_shape = level_shape
self._array = np.zeros((1, *self._level_shape))
if self.pomdp_r:
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
else:
self._array = np.zeros((1, *self._level_shape))
self.capacity = min(capacity, self._array.size)
def as_array(self):
self._array[:] = c.FREE_CELL.value
max_x = self.pomdp_r * 2 + 1 if self.pomdp_r else self._level_shape[0]
if self.pomdp_r:
x, y = max(self.agent.x - self.pomdp_r, 0), max(self.agent.y - self.pomdp_r, 0)
else:
x, y = (0, 0)
for item_idx, item in enumerate(self):
x_diff, y_diff = divmod(item_idx, max_x)
self._array[0, int(x + x_diff), int(y + y_diff)] = item.encoding
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
self._array[0, int(x_diff), int(y_diff)] = item.encoding
return self._array
def __repr__(self):
@ -105,8 +102,9 @@ class Inventories(ObjectRegister):
_accepted_objects = Inventory
is_blocking_light = False
can_be_shadowed = False
hide_from_obs_builder = True
def __init__(self, *args, **kwargs):
def __init__(self, *args, pomdp_r=0, **kwargs):
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
self.is_observable = True
@ -213,13 +211,18 @@ class ItemFactory(BaseFactory):
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_items]
item_register.spawn_items(empty_tiles)
inventories = Inventories(self._level_shape)
inventories = Inventories(self._level_shape if not self.pomdp_r else ((self.pomdp_diameter,) * 2))
inventories.spawn_inventories(self[c.AGENT], self.pomdp_r,
self.item_properties.max_agent_inventory_capacity)
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
return super_entities
def additional_obs_build(self) -> List[np.ndarray]:
super_additional_obs_build = super().additional_obs_build()
super_additional_obs_build.append(self[c.INVENTORY].as_array())
return super_additional_obs_build
def do_item_action(self, agent: Agent):
inventory = self[c.INVENTORY].by_entity(agent)
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
@ -285,7 +288,7 @@ class ItemFactory(BaseFactory):
if self[c.DROP_OFF].by_pos(agent.pos):
info_dict.update({f'{agent.name}_item_dropoff': 1})
reward += 1
reward += 0.5
else:
info_dict.update({f'{agent.name}_item_pickup': 1})
reward += 0.1