Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
@ -6,11 +6,11 @@ import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as c
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
||||
MovingEntityObjectRegister
|
||||
from environments.factory.base.registers import Entities, EntityRegister, EnvObjectRegister, MovingEntityObjectRegister, \
|
||||
BoundRegisterMixin
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
@ -19,7 +19,7 @@ NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
|
||||
|
||||
class Item(MoveableEntity):
|
||||
class Item(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -41,20 +41,21 @@ class Item(MoveableEntity):
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def despawn(self):
|
||||
# Todo: Move this to base class?
|
||||
curr_tile = self.tile
|
||||
curr_tile.leave(self)
|
||||
self._tile = None
|
||||
self._register.notify_change_to_value(self)
|
||||
return True
|
||||
|
||||
class ItemRegister(MovingEntityObjectRegister):
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
class ItemRegister(EntityRegister):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Tile]):
|
||||
items = [Item(tile) for tile in tiles]
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.register_additional_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
@ -63,7 +64,7 @@ class ItemRegister(MovingEntityObjectRegister):
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(UserList):
|
||||
class Inventory(EntityRegister, BoundRegisterMixin):
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
@ -73,19 +74,18 @@ class Inventory(UserList):
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
||||
def __init__(self, obs_shape: (int, int), agent: Agent, capacity: int):
|
||||
super(Inventory, self).__init__()
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self._level_shape = level_shape
|
||||
if self.pomdp_r:
|
||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
||||
else:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
self._obs_shape = obs_shape
|
||||
|
||||
self._array = np.zeros((1, *self._obs_shape))
|
||||
|
||||
self.capacity = min(capacity, self._array.size)
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# ToDo: Make this Lazy
|
||||
for item_idx, item in enumerate(self):
|
||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
||||
@ -110,25 +110,22 @@ class Inventory(UserList):
|
||||
return attr_dict
|
||||
|
||||
|
||||
class Inventories(ObjectRegister):
|
||||
class Inventories(EnvObjectRegister):
|
||||
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
hide_from_obs_builder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, obs_shape, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
self.is_observable = True
|
||||
self._obs_shape = obs_shape
|
||||
|
||||
def as_array(self):
|
||||
# self._array[:] = c.FREE_CELL.value
|
||||
for inv_idx, inventory in enumerate(self):
|
||||
self._array[inv_idx] = inventory.as_array()
|
||||
return self._array
|
||||
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||
|
||||
def spawn_inventories(self, agents, pomdp_r, capacity):
|
||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent, capacity)
|
||||
def spawn_inventories(self, agents, capacity):
|
||||
inventories = [self._accepted_objects(self._obs_shape, agent, capacity)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(inventories)
|
||||
|
||||
@ -183,20 +180,20 @@ class DropOffLocation(Entity):
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
|
||||
|
||||
class DropOffLocations(EntityObjectRegister):
|
||||
class DropOffLocations(EntityRegister):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
|
||||
def as_array(self):
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
# Todo: Which is faster?
|
||||
# indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||
# np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
indices = list(zip([0, ] * len(self), *zip(*[x.pos for x in self])))
|
||||
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
super(DropOffLocations, self).__repr__()
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
@ -241,17 +238,23 @@ class ItemFactory(BaseFactory):
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||
item_register.spawn_items(empty_tiles)
|
||||
|
||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
|
||||
inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
|
||||
self.item_prop.max_agent_inventory_capacity)
|
||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
self._level_shape)
|
||||
inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)
|
||||
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
|
||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
||||
additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
|
||||
return additional_per_agent_obs_build
|
||||
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_raw_observations(agent)
|
||||
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def do_item_action(self, agent: Agent):
|
||||
inventory = self[c.INVENTORY].by_entity(agent)
|
||||
@ -264,7 +267,7 @@ class ItemFactory(BaseFactory):
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
try:
|
||||
inventory.append(item)
|
||||
item.move(self._NO_POS_TILE)
|
||||
item.despawn()
|
||||
return c.VALID
|
||||
except RuntimeError:
|
||||
return c.NOT_VALID
|
||||
@ -308,7 +311,7 @@ class ItemFactory(BaseFactory):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
elif not item.auto_despawn:
|
||||
self[c.ITEM].delete_entity(item)
|
||||
self[c.ITEM].delete_env_object(item)
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -327,12 +330,12 @@ class ItemFactory(BaseFactory):
|
||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
||||
info_dict.update(item_drop_off=1)
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
reward += 0.5
|
||||
reward += 1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
||||
info_dict.update(item_pickup=1)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
reward += 0.1
|
||||
reward += 0.2
|
||||
else:
|
||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
||||
@ -363,13 +366,13 @@ if __name__ == '__main__':
|
||||
|
||||
item_probs = ItemProperties()
|
||||
|
||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||
obs_props = ObservationProperties(render_agents=ARO.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
'allow_diagonal_movement': True,
|
||||
'allow_no_op': False}
|
||||
|
||||
factory = ItemFactory(n_agents=3, done_at_collision=False,
|
||||
factory = ItemFactory(n_agents=2, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
@ -378,7 +381,8 @@ if __name__ == '__main__':
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
obs_space = factory.observation_space
|
||||
obs_space_named = factory.named_observation_space
|
||||
|
||||
for epoch in range(4):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
|
Reference in New Issue
Block a user