Item and Dirt Factory Working again
This commit is contained in:
@ -1,22 +1,30 @@
|
||||
import time
|
||||
from collections import deque, UserList
|
||||
from enum import Enum
|
||||
from collections import deque
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
||||
from environments.factory.base.registers import Entities, EntityRegister, EnvObjectRegister, MovingEntityObjectRegister, \
|
||||
BoundRegisterMixin
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
class Constants(BaseConstants):
|
||||
NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
# Item Env
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
|
||||
class EnvActions(BaseActions):
|
||||
ITEM_ACTION = 'item_action'
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
@ -41,13 +49,9 @@ class Item(Entity):
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def despawn(self):
|
||||
# Todo: Move this to base class?
|
||||
curr_tile = self.tile
|
||||
curr_tile.leave(self)
|
||||
self._tile = None
|
||||
self._register.notify_change_to_value(self)
|
||||
return True
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
assert self._register.__class__.__name__ != ItemRegister.__class__
|
||||
self._tile = no_pos_tile
|
||||
|
||||
|
||||
class ItemRegister(EntityRegister):
|
||||
@ -64,58 +68,38 @@ class ItemRegister(EntityRegister):
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(EntityRegister, BoundRegisterMixin):
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return False
|
||||
class Inventory(BoundRegisterMixin):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||
|
||||
def __init__(self, obs_shape: (int, int), agent: Agent, capacity: int):
|
||||
super(Inventory, self).__init__()
|
||||
self.agent = agent
|
||||
self._obs_shape = obs_shape
|
||||
|
||||
self._array = np.zeros((1, *self._obs_shape))
|
||||
|
||||
self.capacity = min(capacity, self._array.size)
|
||||
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
||||
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.capacity = capacity
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# ToDo: Make this Lazy
|
||||
for item_idx, item in enumerate(self):
|
||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
||||
return self._array
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
return super(Inventory, self).as_array()
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.data})'
|
||||
|
||||
def append(self, item) -> None:
|
||||
if len(self) < self.capacity:
|
||||
super(Inventory, self).append(item)
|
||||
else:
|
||||
raise RuntimeError('Inventory is full')
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return self.agent == entity
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items={val.name: val.summarize_state(**kwargs) for val in self}))
|
||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self):
|
||||
item_to_pop = self[0]
|
||||
self.delete_env_object(item_to_pop)
|
||||
return item_to_pop
|
||||
|
||||
class Inventories(EnvObjectRegister):
|
||||
|
||||
class Inventories(ObjectRegister):
|
||||
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
hide_from_obs_builder = True
|
||||
|
||||
def __init__(self, obs_shape, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
@ -125,7 +109,7 @@ class Inventories(EnvObjectRegister):
|
||||
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||
|
||||
def spawn_inventories(self, agents, capacity):
|
||||
inventories = [self._accepted_objects(self._obs_shape, agent, capacity)
|
||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(inventories)
|
||||
|
||||
@ -141,10 +125,8 @@ class Inventories(EnvObjectRegister):
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, self).summarize_states())
|
||||
return super(Inventories, self).summarize_states(n_steps=n_steps)
|
||||
def summarize_states(self, **kwargs):
|
||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
@ -155,7 +137,7 @@ class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return ITEM_DROP_OFF
|
||||
return Constants.ITEM_DROP_OFF
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
@ -184,24 +166,17 @@ class DropOffLocations(EntityRegister):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
# Todo: Which is faster?
|
||||
# indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||
# np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
indices = list(zip([0, ] * len(self), *zip(*[x.pos for x in self])))
|
||||
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||
return self._array
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
||||
agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items
|
||||
|
||||
|
||||
c = Constants
|
||||
a = EnvActions
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
@ -220,11 +195,11 @@ class ItemFactory(BaseFactory):
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION))
|
||||
super_actions.append(Action(str_ident=a.ITEM_ACTION))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||
def additional_entities(self) -> Dict[(str, Entities)]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_entities = super().additional_entities
|
||||
|
||||
@ -238,19 +213,18 @@ class ItemFactory(BaseFactory):
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||
item_register.spawn_items(empty_tiles)
|
||||
|
||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
self._level_shape)
|
||||
inventories = Inventories(self._obs_shape, self._level_shape)
|
||||
inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)
|
||||
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
|
||||
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_raw_observations(agent)
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
@ -260,14 +234,16 @@ class ItemFactory(BaseFactory):
|
||||
inventory = self[c.INVENTORY].by_entity(agent)
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
if inventory:
|
||||
valid = drop_off.place_item(inventory.pop(0))
|
||||
valid = drop_off.place_item(inventory.pop())
|
||||
return valid
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
try:
|
||||
inventory.append(item)
|
||||
item.despawn()
|
||||
inventory.register_item(item)
|
||||
item.change_register(inventory)
|
||||
self[c.ITEM].delete_env_object(item)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
return c.VALID
|
||||
except RuntimeError:
|
||||
return c.NOT_VALID
|
||||
@ -278,12 +254,9 @@ class ItemFactory(BaseFactory):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == h.EnvActions.ITEM_ACTION:
|
||||
if self.item_prop.agent_can_interact:
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
if action == a.ITEM_ACTION:
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
@ -324,7 +297,7 @@ class ItemFactory(BaseFactory):
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
if h.EnvActions.ITEM_ACTION == agent.temp_action:
|
||||
if a.ITEM_ACTION == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
||||
@ -352,21 +325,21 @@ class ItemFactory(BaseFactory):
|
||||
def render_additional_assets(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]]
|
||||
items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
|
||||
additional_assets.extend(items)
|
||||
drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||
drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||
additional_assets.extend(drop_offs)
|
||||
return additional_assets
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = True
|
||||
|
||||
item_probs = ItemProperties()
|
||||
item_probs = ItemProperties(n_items=30)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=ARO.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||
obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': True,
|
||||
|
Reference in New Issue
Block a user