88 lines
2.9 KiB
Python

from typing import List
import numpy as np
from environments.factory.base.objects import Floor, Agent
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.additional.item.item_entities import Item, DropOffLocation
class ItemRegister(EntityCollection):
    """Entity collection whose accepted object type is ``Item``."""

    _accepted_objects = Item

    def spawn_items(self, tiles: List[Floor]):
        """Create one ``Item`` per given tile and add them to this collection.

        Each item is constructed with its tile and this collection as
        arguments; the whole batch is handed over in a single
        ``add_additional_items`` call.
        """
        self.add_additional_items([Item(floor_tile, self) for floor_tile in tiles])

    def despawn_items(self, items: List[Item]):
        """Delete the given items from this collection.

        A single ``Item`` instance is accepted as well; it is treated as a
        one-element list.
        """
        if isinstance(items, Item):
            items = [items]
        for obsolete_item in items:
            del self[obsolete_item]
class Inventory(BoundEnvObjCollection):
    """Collection bound to one agent that additionally stores a capacity."""

    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
        """Bind the collection to ``agent`` and remember ``capacity``.

        ``agent`` and any extra positional/keyword arguments are forwarded to
        the base class; ``is_blocking_light`` and ``can_be_shadowed`` are
        always passed as ``False``.
        """
        super().__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.capacity = capacity

    @property
    def name(self):
        """Class name followed by the bound entity's name in parentheses."""
        return f'{self.__class__.__name__}({self._bound_entity.name})'

    def as_array(self):
        """Return the base-class array, creating the zero buffer on first use.

        When no array has been allocated yet, a zero array of shape
        ``(1, *self._shape)`` is created before delegating to the parent.
        """
        if self._array is None:
            buffer_shape = (1, *self._shape)
            self._array = np.zeros(buffer_shape)
        return super().as_array()

    def summarize_states(self, **kwargs):
        """Build a dict describing this inventory.

        Contains the stringified public instance attributes (names not
        starting with ``_`` and not equal to ``data``), an ``items`` entry
        mapping each key to that entry's ``summarize_state(**kwargs)`` result,
        and a ``name`` entry.
        """
        summary = {}
        for attr_name, attr_value in self.__dict__.items():
            # Private attributes and the raw ``data`` attribute are left out.
            if attr_name.startswith('_') or attr_name == 'data':
                continue
            summary[attr_name] = str(attr_value)
        summary['items'] = {key: entry.summarize_state(**kwargs) for key, entry in self.items()}
        summary['name'] = self.name
        return summary

    def pop(self):
        """Remove the element at index 0 from this inventory and return it."""
        first_item = self[0]
        self.delete_env_object(first_item)
        return first_item
class Inventories(ObjectCollection):
    """Collection of ``Inventory`` objects, one created per spawned agent."""

    _accepted_objects = Inventory
    is_blocking_light = False
    can_be_shadowed = False

    def __init__(self, obs_shape, *args, **kwargs):
        """Store ``obs_shape`` and initialise the base collection.

        The base class is always initialised with ``is_per_agent=True`` and
        ``individual_slices=True``.
        """
        super().__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self._obs_shape = obs_shape

    def as_array(self):
        """Stack the arrays of all contained inventories along a new first axis."""
        per_inventory_arrays = [inventory.as_array() for inventory in self]
        return np.stack(per_inventory_arrays)

    def spawn_inventories(self, agents, capacity):
        """Create one inventory per agent and add them to this collection.

        Each inventory is constructed from the agent, ``capacity`` and the
        stored observation shape (passed as an extra positional argument).
        """
        make_inventory = self._accepted_objects
        self.add_additional_items(
            [make_inventory(agent, capacity, self._obs_shape) for agent in agents]
        )

    def idx_by_entity(self, entity):
        """Return the position of the first inventory belonging to ``entity``, or ``None``."""
        matching_indices = (idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity))
        return next(matching_indices, None)

    def by_entity(self, entity):
        """Return the first inventory belonging to ``entity``, or ``None``."""
        matching_inventories = (inv for inv in self if inv.belongs_to_entity(entity))
        return next(matching_inventories, None)

    def summarize_states(self, **kwargs):
        """Map each key of this collection to its inventory's state summary."""
        summaries = {}
        for key, inventory in self.items():
            summaries[key] = inventory.summarize_states(**kwargs)
        return summaries
class DropOffLocations(EntityCollection):
    """Entity collection whose accepted object type is ``DropOffLocation``."""

    _accepted_objects = DropOffLocation