Rework of Observations and Entity Differentiation, lazy obs build by notification

This commit is contained in:
Steffen Illium
2021-12-22 10:48:36 +01:00
parent 7f7a3d9a3b
commit b43f595207
14 changed files with 961 additions and 487 deletions

View File

@ -6,11 +6,11 @@ import numpy as np
import random
from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as c
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
MovingEntityObjectRegister
from environments.factory.base.registers import Entities, EntityRegister, EnvObjectRegister, MovingEntityObjectRegister, \
BoundRegisterMixin
from environments.factory.base.renderer import RenderEntity
@ -19,7 +19,7 @@ NO_ITEM = 0
ITEM_DROP_OFF = 1
class Item(MoveableEntity):
class Item(Entity):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -41,20 +41,21 @@ class Item(MoveableEntity):
def set_auto_despawn(self, auto_despawn):
self._auto_despawn = auto_despawn
def despawn(self):
# Todo: Move this to base class?
curr_tile = self.tile
curr_tile.leave(self)
self._tile = None
self._register.notify_change_to_value(self)
return True
class ItemRegister(MovingEntityObjectRegister):
def as_array(self):
self._array[:] = c.FREE_CELL.value
for item in self:
if item.pos != c.NO_POS.value:
self._array[0, item.x, item.y] = item.encoding
return self._array
class ItemRegister(EntityRegister):
_accepted_objects = Item
def spawn_items(self, tiles: List[Tile]):
items = [Item(tile) for tile in tiles]
items = [Item(tile, self) for tile in tiles]
self.register_additional_items(items)
def despawn_items(self, items: List[Item]):
@ -63,7 +64,7 @@ class ItemRegister(MovingEntityObjectRegister):
del self[item]
class Inventory(UserList):
class Inventory(EntityRegister, BoundRegisterMixin):
@property
def is_blocking_light(self):
@ -73,19 +74,18 @@ class Inventory(UserList):
def name(self):
return f'{self.__class__.__name__}({self.agent.name})'
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
def __init__(self, obs_shape: (int, int), agent: Agent, capacity: int):
super(Inventory, self).__init__()
self.agent = agent
self.pomdp_r = pomdp_r
self._level_shape = level_shape
if self.pomdp_r:
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
else:
self._array = np.zeros((1, *self._level_shape))
self._obs_shape = obs_shape
self._array = np.zeros((1, *self._obs_shape))
self.capacity = min(capacity, self._array.size)
def as_array(self):
self._array[:] = c.FREE_CELL.value
# ToDo: Make this Lazy
for item_idx, item in enumerate(self):
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
self._array[0, int(x_diff), int(y_diff)] = item.encoding
@ -110,25 +110,22 @@ class Inventory(UserList):
return attr_dict
class Inventories(ObjectRegister):
class Inventories(EnvObjectRegister):
_accepted_objects = Inventory
is_blocking_light = False
can_be_shadowed = False
hide_from_obs_builder = True
def __init__(self, *args, **kwargs):
def __init__(self, obs_shape, *args, **kwargs):
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
self.is_observable = True
self._obs_shape = obs_shape
def as_array(self):
# self._array[:] = c.FREE_CELL.value
for inv_idx, inventory in enumerate(self):
self._array[inv_idx] = inventory.as_array()
return self._array
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
def spawn_inventories(self, agents, pomdp_r, capacity):
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent, capacity)
def spawn_inventories(self, agents, capacity):
inventories = [self._accepted_objects(self._obs_shape, agent, capacity)
for _, agent in enumerate(agents)]
self.register_additional_items(inventories)
@ -183,20 +180,20 @@ class DropOffLocation(Entity):
return super().summarize_state(n_steps=n_steps)
class DropOffLocations(EntityObjectRegister):
class DropOffLocations(EntityRegister):
_accepted_objects = DropOffLocation
def as_array(self):
@DeprecationWarning
def Xas_array(self):
# Todo: Which is faster?
# indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
# np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
self._array[:] = c.FREE_CELL.value
for item in self:
if item.pos != c.NO_POS.value:
self._array[0, item.x, item.y] = item.encoding
indices = list(zip([0, ] * len(self), *zip(*[x.pos for x in self])))
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
return self._array
def __repr__(self):
super(DropOffLocations, self).__repr__()
class ItemProperties(NamedTuple):
n_items: int = 5 # How many items are there at the same time
@ -241,17 +238,23 @@ class ItemFactory(BaseFactory):
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
item_register.spawn_items(empty_tiles)
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
self.item_prop.max_agent_inventory_capacity)
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
self._level_shape)
inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
return super_entities
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
return additional_per_agent_obs_build
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
additional_raw_observations = super()._additional_raw_observations(agent)
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
return additional_raw_observations
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
return additional_observations
def do_item_action(self, agent: Agent):
inventory = self[c.INVENTORY].by_entity(agent)
@ -264,7 +267,7 @@ class ItemFactory(BaseFactory):
elif item := self[c.ITEM].by_pos(agent.pos):
try:
inventory.append(item)
item.move(self._NO_POS_TILE)
item.despawn()
return c.VALID
except RuntimeError:
return c.NOT_VALID
@ -308,7 +311,7 @@ class ItemFactory(BaseFactory):
if item.auto_despawn >= 1:
item.set_auto_despawn(item.auto_despawn-1)
elif not item.auto_despawn:
self[c.ITEM].delete_entity(item)
self[c.ITEM].delete_env_object(item)
else:
pass
@ -327,12 +330,12 @@ class ItemFactory(BaseFactory):
info_dict.update({f'{agent.name}_item_drop_off': 1})
info_dict.update(item_drop_off=1)
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
reward += 0.5
reward += 1
else:
info_dict.update({f'{agent.name}_item_pickup': 1})
info_dict.update(item_pickup=1)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
reward += 0.1
reward += 0.2
else:
if self[c.DROP_OFF].by_pos(agent.pos):
info_dict.update({f'{agent.name}_failed_drop_off': 1})
@ -363,13 +366,13 @@ if __name__ == '__main__':
item_probs = ItemProperties()
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
obs_props = ObservationProperties(render_agents=ARO.SEPERATE, omit_agent_self=True, pomdp_r=2)
move_props = {'allow_square_movement': True,
'allow_diagonal_movement': False,
'allow_diagonal_movement': True,
'allow_no_op': False}
factory = ItemFactory(n_agents=3, done_at_collision=False,
factory = ItemFactory(n_agents=2, done_at_collision=False,
level_name='rooms', max_steps=400,
obs_prop=obs_props, parse_doors=True,
record_episodes=True, verbose=True,
@ -378,7 +381,8 @@ if __name__ == '__main__':
# noinspection DuplicatedCode
n_actions = factory.action_space.n - 1
_ = factory.observation_space
obs_space = factory.observation_space
obs_space_named = factory.named_observation_space
for epoch in range(4):
random_actions = [[random.randint(0, n_actions) for _