Adjustments and Documentation

This commit is contained in:
Steffen Illium
2022-04-11 16:15:44 +02:00
parent 3e19970a60
commit 0218f8f4e9
12 changed files with 394 additions and 182 deletions

View File

@@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
np.argwhere(level_array == c.OCCUPIED_CELL),
self._level_shape
)
self._entities.register_additional_items({c.WALLS: walls})
self._entities.add_additional_items({c.WALLS: walls})
# Floor
floor = Floors.from_argwhere_coordinates(
np.argwhere(level_array == c.FREE_CELL),
self._level_shape
)
self._entities.register_additional_items({c.FLOOR: floor})
self._entities.add_additional_items({c.FLOOR: floor})
# NOPOS
self._NO_POS_TILE = Floor(c.NO_POS, None)
@@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
entity_kwargs=dict(context=floor)
)
self._entities.register_additional_items({c.DOORS: doors})
self._entities.add_additional_items({c.DOORS: doors})
# Actions
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
if additional_actions := self.actions_hook:
self._actions.register_additional_items(additional_actions)
self._actions.add_additional_items(additional_actions)
# Agents
agents_to_spawn = self.n_agents-len(self._injected_agents)
@@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
if self._injected_agents:
initialized_injections = list()
for i, injection in enumerate(self._injected_agents):
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
initialized_injections.append(agents[-1])
self._initialized_injections = initialized_injections
self._entities.register_additional_items({c.AGENT: agents})
self._entities.add_additional_items({c.AGENT: agents})
if self.obs_prop.additional_agent_placeholder is not None:
# TODO: Make this accept Lists for multiple placeholders
@@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
fill_value=self.obs_prop.additional_agent_placeholder)
)
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
# Additional Entitites from SubEnvs
if additional_entities := self.entities_hook:
self._entities.register_additional_items(additional_entities)
self._entities.add_additional_items(additional_entities)
if self.obs_prop.show_global_position_info:
global_positions = GlobalPositions(self._level_shape)
# This moved into the GlobalPosition object
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
global_positions.spawn_global_position_objects(self[c.AGENT])
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
# Return
return self._entities
@@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
# Actions
x_diff, y_diff = h.ACTIONMAP[action.identifier]
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
x_new = agent.x + x_diff
y_new = agent.y + y_diff

View File

@@ -72,15 +72,15 @@ class EnvObject(Object):
def encoding(self):
return c.OCCUPIED_CELL
def __init__(self, register, **kwargs):
def __init__(self, collection, **kwargs):
super(EnvObject, self).__init__(**kwargs)
self._register = register
self._collection = collection
def change_register(self, register):
register.register_item(self)
self._register.delete_env_object(self)
self._register = register
return self._register == register
def change_parent_collection(self, other_collection):
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
# With Rendering
@@ -153,7 +153,7 @@ class MoveableEntity(Entity):
curr_tile.leave(self)
self._tile = next_tile
self._last_tile = curr_tile
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
@@ -371,13 +371,13 @@ class Door(Entity):
def _open(self):
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
self._state = c.OPEN_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
self.time_to_close = self.auto_close_interval
def _close(self):
self.connectivity.remove_node(self.pos)
self._state = c.CLOSED_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
def is_linked(self, old_pos, new_pos):
try:

View File

@@ -13,11 +13,11 @@ from environments import helpers as h
from environments.helpers import Constants as c
##########################################################################
# ##################### Base Register Definition ####################### #
# ################## Base Collections Definition ####################### #
##########################################################################
class ObjectRegister:
class ObjectCollection:
_accepted_objects = Object
@property
@@ -25,59 +25,59 @@ class ObjectRegister:
return f'{self.__class__.__name__}'
def __init__(self, *args, **kwargs):
self._register = dict()
self._collection = dict()
def __len__(self):
return len(self._register)
return len(self._collection)
def __iter__(self):
return iter(self.values())
def register_item(self, other: _accepted_objects):
def add_item(self, other: _accepted_objects):
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,'
self._register.update({other.name: other})
self._collection.update({other.name: other})
return self
def register_additional_items(self, others: List[_accepted_objects]):
def add_additional_items(self, others: List[_accepted_objects]):
for other in others:
self.register_item(other)
self.add_item(other)
return self
def keys(self):
return self._register.keys()
return self._collection.keys()
def values(self):
return self._register.values()
return self._collection.values()
def items(self):
return self._register.items()
return self._collection.items()
def _get_index(self, item):
try:
return next(i for i, v in enumerate(self._register.values()) if v == item)
return next(i for i, v in enumerate(self._collection.values()) if v == item)
except StopIteration:
return None
def __getitem__(self, item):
if isinstance(item, (int, np.int64, np.int32)):
if item < 0:
item = len(self._register) - abs(item)
item = len(self._collection) - abs(item)
try:
return next(v for i, v in enumerate(self._register.values()) if i == item)
return next(v for i, v in enumerate(self._collection.values()) if i == item)
except StopIteration:
return None
try:
return self._register[item]
return self._collection[item]
except KeyError:
return None
def __repr__(self):
return f'{self.__class__.__name__}[{self._register}]'
return f'{self.__class__.__name__}[{self._collection}]'
class EnvObjectRegister(ObjectRegister):
class EnvObjectCollection(ObjectCollection):
_accepted_objects = EnvObject
@@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
is_blocking_light: bool = False,
can_collide: bool = False,
can_be_shadowed: bool = True, **kwargs):
super(EnvObjectRegister, self).__init__(*args, **kwargs)
super(EnvObjectCollection, self).__init__(*args, **kwargs)
self._shape = obs_shape
self._array = None
self._individual_slices = individual_slices
@@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
self.can_be_shadowed = can_be_shadowed
self.can_collide = can_collide
def register_item(self, other: EnvObject):
super(EnvObjectRegister, self).register_item(other)
def add_item(self, other: EnvObject):
super(EnvObjectCollection, self).add_item(other)
if self._array is None:
self._array = np.zeros((1, *self._shape))
else:
@@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
if self._individual_slices:
self._array = np.delete(self._array, idx, axis=0)
else:
self.notify_change_to_free(self._register[name])
self.notify_change_to_free(self._collection[name])
# Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions
# in the observation array are result of enumeration. They can overide each other.
# Todo: Find a better solution
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
self._refresh_arrays()
del self._register[name]
del self._collection[name]
def delete_env_object(self, env_object: EnvObject):
del self[env_object.name]
@@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
del self[name]
class EntityRegister(EnvObjectRegister, ABC):
class EntityCollection(EnvObjectCollection, ABC):
_accepted_objects = Entity
@classmethod
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# objects_name = cls._accepted_objects.__name__
register_obj = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
collection = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, collection, str_ident=i,
**entity_kwargs if entity_kwargs is not None else {})
for i, tile in enumerate(tiles)]
register_obj.register_additional_items(entities)
return register_obj
collection.add_additional_items(entities)
return collection
@classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
@@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
return [entity.tile for entity in self]
def __init__(self, level_shape, *args, **kwargs):
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
self._lazy_eval_transforms = []
def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj)
super(EntityRegister, self).__delitem__(name)
super(EntityCollection, self).__delitem__(name)
def as_array(self):
if self._lazy_eval_transforms:
@@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
return None
class BoundEnvObjRegister(EnvObjectRegister, ABC):
class BoundEnvObjCollection(EnvObjectCollection, ABC):
def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
return self._array[self.idx_by_entity(entity)]
class MovingEntityObjectRegister(EntityRegister, ABC):
class MovingEntityObjectCollection(EntityCollection, ABC):
def __init__(self, *args, **kwargs):
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
def notify_change_to_value(self, entity):
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
if entity.last_pos != c.NO_POS:
try:
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
@@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
##########################################################################
# ################# Objects and Entity Registers ####################### #
# ################# Objects and Entity Collection ###################### #
##########################################################################
class GlobalPositions(EnvObjectRegister):
class GlobalPositions(EnvObjectCollection):
_accepted_objects = GlobalPosition
@@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
global_positions = [self._accepted_objects(self._shape, agent, self)
for _, agent in enumerate(agents)]
# noinspection PyTypeChecker
self.register_additional_items(global_positions)
self.add_additional_items(global_positions)
def summarize_states(self, n_steps=None):
return {}
@@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
return None
class PlaceHolders(EnvObjectRegister):
class PlaceHolders(EnvObjectCollection):
_accepted_objects = PlaceHolder
def __init__(self, *args, **kwargs):
@@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
# objects_name = cls._accepted_objects.__name__
if isinstance(values, (str, numbers.Number)):
values = [values]
register_obj = cls(*args, **kwargs)
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
collection = cls(*args, **kwargs)
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
**object_kwargs if object_kwargs is not None else {})
for i, value in enumerate(values)]
register_obj.register_additional_items(objects)
return register_obj
collection.add_additional_items(objects)
return collection
# noinspection DuplicatedCode
def as_array(self):
@@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
return self._array
class Entities(ObjectRegister):
_accepted_objects = EntityRegister
class Entities(ObjectCollection):
_accepted_objects = EntityCollection
@property
def arrays(self):
@@ -352,7 +352,7 @@ class Entities(ObjectRegister):
@property
def names(self):
return list(self._register.keys())
return list(self._collection.keys())
def __init__(self):
super(Entities, self).__init__()
@@ -360,21 +360,21 @@ class Entities(ObjectRegister):
def iter_individual_entitites(self):
return iter((x for sublist in self.values() for x in sublist))
def register_item(self, other: dict):
def add_item(self, other: dict):
assert not any([key for key in other.keys() if key in self.keys()]), \
"This group of entities has already been registered!"
self._register.update(other)
"This group of entities has already been added!"
self._collection.update(other)
return self
def register_additional_items(self, others: Dict):
return self.register_item(others)
def add_additional_items(self, others: Dict):
return self.add_item(others)
def by_pos(self, pos: (int, int)):
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
return found_entities
class Walls(EntityRegister):
class Walls(EntityCollection):
_accepted_objects = Wall
def as_array(self):
@@ -396,7 +396,7 @@ class Walls(EntityRegister):
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
tiles.add_additional_items(
[cls._accepted_objects(pos, tiles)
for pos in argwhere_coordinates]
)
@@ -441,7 +441,7 @@ class Floors(Walls):
return {}
class Agents(MovingEntityObjectRegister):
class Agents(MovingEntityObjectCollection):
_accepted_objects = Agent
def __init__(self, *args, **kwargs):
@@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
old_agent = self[key]
self[key].tile.leave(self[key])
agent._name = old_agent.name
self._register[agent.name] = agent
self._collection[agent.name] = agent
class Doors(EntityRegister):
class Doors(EntityCollection):
def __init__(self, *args, have_area: bool = False, **kwargs):
self.have_area = have_area
@@ -490,7 +490,7 @@ class Doors(EntityRegister):
return super(Doors, self).as_array()
class Actions(ObjectRegister):
class Actions(ObjectCollection):
_accepted_objects = Action
@property
@@ -507,22 +507,22 @@ class Actions(ObjectRegister):
# Move this to Baseclass, Env init?
if self.allow_square_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()])
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()])
if self.allow_diagonal_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._register.copy()
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._collection.copy()
if self.can_use_doors:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
if self.allow_no_op:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values()
class Zones(ObjectRegister):
class Zones(ObjectCollection):
@property
def accounting_zones(self):
@@ -551,5 +551,5 @@ class Zones(ObjectRegister):
def __getitem__(self, item):
return self._zone_slices[item]
def register_additional_items(self, other: Union[str, List[str]]):
def add_additional_items(self, other: Union[str, List[str]]):
raise AttributeError('You are not allowed to add additional Zones in runtime.')

View File

@@ -4,7 +4,7 @@ import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
@@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject):
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
@@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject):
return attr_dict
class BatteriesRegister(EnvObjectRegister):
class BatteriesRegister(EnvObjectCollection):
_accepted_objects = Battery
@@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister):
def spawn_batteries(self, agents, initial_charge_level):
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
self.register_additional_items(batteries)
self.add_additional_items(batteries)
def summarize_states(self, n_steps=None):
# as dict with additional nesting
@@ -140,7 +140,7 @@ class ChargePod(Entity):
return summary
class ChargePods(EntityRegister):
class ChargePods(EntityCollection):
_accepted_objects = ChargePod

View File

@@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity
@@ -73,7 +73,7 @@ class Destination(Entity):
return state_summary
class Destinations(EntityRegister):
class Destinations(EntityCollection):
_accepted_objects = Destination
@@ -208,13 +208,13 @@ class DestFactory(BaseFactory):
n_dest_to_spawn = len(destinations_to_spawn)
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
@@ -231,7 +231,7 @@ class DestFactory(BaseFactory):
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
for dest in list(self[c.DEST].values()):
if dest.is_considered_reached:
dest.change_register(self[c.DEST])
dest.change_parent_collection(self[c.DEST])
self._dest_spawn_timer[dest.name] = 0
self.print(f'{dest.name} is reached now, removing...')
else:

View File

@@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, Floor
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity
from environments.utility_classes import ObservationProperties
@@ -61,7 +61,7 @@ class Dirt(Entity):
def set_new_amount(self, amount):
self._amount = amount
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
@@ -69,7 +69,7 @@ class Dirt(Entity):
return state_dict
class DirtRegister(EntityRegister):
class DirtRegister(EntityCollection):
_accepted_objects = Dirt
@@ -93,7 +93,7 @@ class DirtRegister(EntityRegister):
dirt = self.by_pos(tile.pos)
if dirt is None:
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
self.register_item(dirt)
self.add_item(dirt)
else:
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))

View File

@@ -5,10 +5,10 @@ import numpy as np
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
from environments.factory.base.objects import Floor
from environments.factory.base.registers import Floors, Entities, EntityRegister
from environments.factory.base.registers import Floors, Entities, EntityCollection
class Machines(EntityRegister):
class Machines(EntityCollection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Floor
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.base.renderer import RenderEntity
@@ -53,17 +53,17 @@ class Item(Entity):
self._auto_despawn = auto_despawn
def set_tile_to(self, no_pos_tile):
assert self._register.__class__.__name__ != ItemRegister.__class__
assert self._collection.__class__.__name__ != ItemRegister.__class__
self._tile = no_pos_tile
class ItemRegister(EntityRegister):
class ItemRegister(EntityCollection):
_accepted_objects = Item
def spawn_items(self, tiles: List[Floor]):
items = [Item(tile, self) for tile in tiles]
self.register_additional_items(items)
self.add_additional_items(items)
def despawn_items(self, items: List[Item]):
items = [items] if isinstance(items, Item) else items
@@ -71,7 +71,7 @@ class ItemRegister(EntityRegister):
del self[item]
class Inventory(BoundEnvObjRegister):
class Inventory(BoundEnvObjCollection):
@property
def name(self):
@@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister):
return item_to_pop
class Inventories(ObjectRegister):
class Inventories(ObjectCollection):
_accepted_objects = Inventory
is_blocking_light = False
@@ -114,7 +114,7 @@ class Inventories(ObjectRegister):
def spawn_inventories(self, agents, capacity):
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
for _, agent in enumerate(agents)]
self.register_additional_items(inventories)
self.add_additional_items(inventories)
def idx_by_entity(self, entity):
try:
@@ -161,7 +161,7 @@ class DropOffLocation(Entity):
return super().summarize_state(n_steps=n_steps)
class DropOffLocations(EntityRegister):
class DropOffLocations(EntityCollection):
_accepted_objects = DropOffLocation
@@ -250,7 +250,7 @@ class ItemFactory(BaseFactory):
reason=a.ITEM_ACTION, info=info_dict)
return valid, reward
elif item := self[c.ITEM].by_pos(agent.pos):
item.change_register(inventory)
item.change_parent_collection(inventory)
item.set_tile_to(self._NO_POS_TILE)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}