88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
from typing import Iterable, List, Union

import numpy as np

from environments.factory.base.objects import Floor, Agent
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
|
|
|
|
|
class ItemRegister(EntityCollection):
    """Entity collection that accepts :class:`Item` objects.

    Items are created on floor tiles via :meth:`spawn_items` and removed
    again via :meth:`despawn_items`.
    """

    _accepted_objects = Item

    def spawn_items(self, tiles: Iterable[Floor]):
        """Create one :class:`Item` per tile and add them all to this register.

        :param tiles: Floor tiles to place items on; each item is built as
            ``Item(tile, self)`` so it references this register.
        """
        items = [Item(tile, self) for tile in tiles]
        self.add_additional_items(items)

    def despawn_items(self, items: Union[Item, Iterable[Item]]):
        """Remove one or several items from this register.

        :param items: A single :class:`Item` or an iterable of items.
        """
        # Snapshot the argument before deleting. A caller may pass this
        # register itself (or a live view of it), and deleting entries while
        # iterating that same container could skip elements or raise.
        items = [items] if isinstance(items, Item) else list(items)
        for item in items:
            del self[item]
|
|
|
|
|
|
class Inventory(BoundEnvObjCollection):
    """Object collection bound to a single agent, with a ``capacity`` attribute.

    NOTE(review): ``capacity`` is stored here but not enforced by any method
    of this class; presumably callers check it before adding entries.
    """

    @property
    def name(self):
        """Display name of the form ``<ClassName>(<bound entity name>)``."""
        cls_name = self.__class__.__name__
        owner_name = self._bound_entity.name
        return f'{cls_name}({owner_name})'

    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
        """Bind the inventory to ``agent`` and remember its ``capacity``.

        Remaining positional/keyword arguments are forwarded to
        ``BoundEnvObjCollection``. The inventory is always created with
        ``is_blocking_light=False`` and ``can_be_shadowed=False``.
        """
        super().__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.capacity = capacity

    def as_array(self):
        """Return the base-class array, first allocating a zero buffer if none exists.

        The lazily created buffer has shape ``(1, *self._shape)``.
        """
        if self._array is None:
            self._array = np.zeros((1,) + tuple(self._shape))
        return super().as_array()

    def summarize_states(self, **kwargs):
        """Summarize this inventory as a dict.

        The result contains every public instance attribute (stringified,
        excluding ``data``), then ``items`` (per-entry ``summarize_state``
        results), then ``name``.
        """
        public_attrs = {
            attr: str(value)
            for attr, value in self.__dict__.items()
            if not (attr.startswith('_') or attr == 'data')
        }
        item_states = {key: val.summarize_state(**kwargs) for key, val in self.items()}
        # Merge in the same order the fields are computed: attributes, items, name.
        return {**public_attrs, 'items': item_states, 'name': self.name}

    def pop(self):
        """Remove the entry at index 0 and return it.

        NOTE(review): behaviour on an empty inventory depends on the base
        class's ``__getitem__`` (presumably raises) — confirm.
        """
        first_entry = self[0]
        self.delete_env_object(first_entry)
        return first_entry
|
|
|
|
|
|
class Inventories(ObjectCollection):
    """Per-agent collection of :class:`Inventory` objects.

    Created with ``is_per_agent=True`` and ``individual_slices=True``; every
    inventory it spawns receives the shared observation shape.
    """

    _accepted_objects = Inventory
    is_blocking_light = False
    can_be_shadowed = False

    def __init__(self, obs_shape, *args, **kwargs):
        """Store ``obs_shape`` for spawned inventories; other args go to ``ObjectCollection``."""
        super().__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self._obs_shape = obs_shape

    def as_array(self):
        """Stack every inventory's ``as_array()`` result along a new leading axis.

        Note: ``np.stack`` raises ``ValueError`` when the collection is empty.
        """
        return np.stack([inventory.as_array() for inventory in self])

    def spawn_inventories(self, agents, capacity):
        """Create one inventory of the given ``capacity`` per agent and add them all."""
        inventories = [self._accepted_objects(agent, capacity, self._obs_shape) for agent in agents]
        self.add_additional_items(inventories)

    def idx_by_entity(self, entity):
        """Return the index of the first inventory belonging to ``entity``, or ``None``."""
        return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)), None)

    def by_entity(self, entity):
        """Return the first inventory belonging to ``entity``, or ``None``."""
        return next((inv for inv in self if inv.belongs_to_entity(entity)), None)

    def summarize_states(self, **kwargs):
        """Map each key from ``self.items()`` to that inventory's ``summarize_states`` dict."""
        return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
|
|
|
|
|
class DropOffLocations(EntityCollection):
    """Entity collection that accepts :class:`DropOffLocation` objects."""

    _accepted_objects = DropOffLocation
|