Adjustments and Documentation
This commit is contained in:
@@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.WALLS: walls})
|
||||
self._entities.add_additional_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.FREE_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
self._entities.add_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||
@@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
||||
entity_kwargs=dict(context=floor)
|
||||
)
|
||||
self._entities.register_additional_items({c.DOORS: doors})
|
||||
self._entities.add_additional_items({c.DOORS: doors})
|
||||
|
||||
# Actions
|
||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||
if additional_actions := self.actions_hook:
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
self._actions.add_additional_items(additional_actions)
|
||||
|
||||
# Agents
|
||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||
@@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
|
||||
if self._injected_agents:
|
||||
initialized_injections = list()
|
||||
for i, injection in enumerate(self._injected_agents):
|
||||
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||
initialized_injections.append(agents[-1])
|
||||
self._initialized_injections = initialized_injections
|
||||
self._entities.register_additional_items({c.AGENT: agents})
|
||||
self._entities.add_additional_items({c.AGENT: agents})
|
||||
|
||||
if self.obs_prop.additional_agent_placeholder is not None:
|
||||
# TODO: Make this accept Lists for multiple placeholders
|
||||
@@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
|
||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||
)
|
||||
|
||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# Additional Entitites from SubEnvs
|
||||
if additional_entities := self.entities_hook:
|
||||
self._entities.register_additional_items(additional_entities)
|
||||
self._entities.add_additional_items(additional_entities)
|
||||
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_positions = GlobalPositions(self._level_shape)
|
||||
# This moved into the GlobalPosition object
|
||||
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
|
||||
# Return
|
||||
return self._entities
|
||||
@@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||
# Actions
|
||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
||||
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
|
||||
x_new = agent.x + x_diff
|
||||
y_new = agent.y + y_diff
|
||||
|
||||
|
@@ -72,15 +72,15 @@ class EnvObject(Object):
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL
|
||||
|
||||
def __init__(self, register, **kwargs):
|
||||
def __init__(self, collection, **kwargs):
|
||||
super(EnvObject, self).__init__(**kwargs)
|
||||
self._register = register
|
||||
self._collection = collection
|
||||
|
||||
def change_register(self, register):
|
||||
register.register_item(self)
|
||||
self._register.delete_env_object(self)
|
||||
self._register = register
|
||||
return self._register == register
|
||||
def change_parent_collection(self, other_collection):
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
# With Rendering
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ class MoveableEntity(Entity):
|
||||
curr_tile.leave(self)
|
||||
self._tile = next_tile
|
||||
self._last_tile = curr_tile
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
@@ -371,13 +371,13 @@ class Door(Entity):
|
||||
def _open(self):
|
||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||
self._state = c.OPEN_DOOR
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self.connectivity.remove_node(self.pos)
|
||||
self._state = c.CLOSED_DOOR
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def is_linked(self, old_pos, new_pos):
|
||||
try:
|
||||
|
@@ -13,11 +13,11 @@ from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
##########################################################################
|
||||
# ##################### Base Register Definition ####################### #
|
||||
# ################## Base Collections Definition ####################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class ObjectRegister:
|
||||
class ObjectCollection:
|
||||
_accepted_objects = Object
|
||||
|
||||
@property
|
||||
@@ -25,59 +25,59 @@ class ObjectRegister:
|
||||
return f'{self.__class__.__name__}'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register = dict()
|
||||
self._collection = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
return len(self._collection)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.values())
|
||||
|
||||
def register_item(self, other: _accepted_objects):
|
||||
def add_item(self, other: _accepted_objects):
|
||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||
f'{self._accepted_objects}, ' \
|
||||
f'but were {other.__class__}.,'
|
||||
self._register.update({other.name: other})
|
||||
self._collection.update({other.name: other})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: List[_accepted_objects]):
|
||||
def add_additional_items(self, others: List[_accepted_objects]):
|
||||
for other in others:
|
||||
self.register_item(other)
|
||||
self.add_item(other)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
return self._register.keys()
|
||||
return self._collection.keys()
|
||||
|
||||
def values(self):
|
||||
return self._register.values()
|
||||
return self._collection.values()
|
||||
|
||||
def items(self):
|
||||
return self._register.items()
|
||||
return self._collection.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
||||
return next(i for i, v in enumerate(self._collection.values()) if v == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
item = len(self._register) - abs(item)
|
||||
item = len(self._collection) - abs(item)
|
||||
try:
|
||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
||||
return next(v for i, v in enumerate(self._collection.values()) if i == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
try:
|
||||
return self._register[item]
|
||||
return self._collection[item]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{self._register}]'
|
||||
return f'{self.__class__.__name__}[{self._collection}]'
|
||||
|
||||
|
||||
class EnvObjectRegister(ObjectRegister):
|
||||
class EnvObjectCollection(ObjectCollection):
|
||||
|
||||
_accepted_objects = EnvObject
|
||||
|
||||
@@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
|
||||
is_blocking_light: bool = False,
|
||||
can_collide: bool = False,
|
||||
can_be_shadowed: bool = True, **kwargs):
|
||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||
super(EnvObjectCollection, self).__init__(*args, **kwargs)
|
||||
self._shape = obs_shape
|
||||
self._array = None
|
||||
self._individual_slices = individual_slices
|
||||
@@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.can_collide = can_collide
|
||||
|
||||
def register_item(self, other: EnvObject):
|
||||
super(EnvObjectRegister, self).register_item(other)
|
||||
def add_item(self, other: EnvObject):
|
||||
super(EnvObjectCollection, self).add_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
else:
|
||||
@@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
|
||||
if self._individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
else:
|
||||
self.notify_change_to_free(self._register[name])
|
||||
self.notify_change_to_free(self._collection[name])
|
||||
# Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions
|
||||
# in the observation array are result of enumeration. They can overide each other.
|
||||
# Todo: Find a better solution
|
||||
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
|
||||
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
|
||||
self._refresh_arrays()
|
||||
del self._register[name]
|
||||
del self._collection[name]
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
del self[env_object.name]
|
||||
@@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
|
||||
del self[name]
|
||||
|
||||
|
||||
class EntityRegister(EnvObjectRegister, ABC):
|
||||
class EntityCollection(EnvObjectCollection, ABC):
|
||||
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
register_obj = cls(*args, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
||||
collection = cls(*args, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, collection, str_ident=i,
|
||||
**entity_kwargs if entity_kwargs is not None else {})
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
collection.add_additional_items(entities)
|
||||
return collection
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||
@@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, level_shape, *args, **kwargs):
|
||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj)
|
||||
super(EntityRegister, self).__delitem__(name)
|
||||
super(EntityCollection, self).__delitem__(name)
|
||||
|
||||
def as_array(self):
|
||||
if self._lazy_eval_transforms:
|
||||
@@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return None
|
||||
|
||||
|
||||
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
class BoundEnvObjCollection(EnvObjectCollection, ABC):
|
||||
|
||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
class MovingEntityObjectCollection(EntityCollection, ABC):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
|
||||
|
||||
def notify_change_to_value(self, entity):
|
||||
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
||||
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
|
||||
if entity.last_pos != c.NO_POS:
|
||||
try:
|
||||
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
||||
@@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ################# Objects and Entity Registers ####################### #
|
||||
# ################# Objects and Entity Collection ###################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class GlobalPositions(EnvObjectRegister):
|
||||
class GlobalPositions(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = GlobalPosition
|
||||
|
||||
@@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||
for _, agent in enumerate(agents)]
|
||||
# noinspection PyTypeChecker
|
||||
self.register_additional_items(global_positions)
|
||||
self.add_additional_items(global_positions)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
@@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
return None
|
||||
|
||||
|
||||
class PlaceHolders(EnvObjectRegister):
|
||||
class PlaceHolders(EnvObjectCollection):
|
||||
_accepted_objects = PlaceHolder
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
if isinstance(values, (str, numbers.Number)):
|
||||
values = [values]
|
||||
register_obj = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||
collection = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
|
||||
**object_kwargs if object_kwargs is not None else {})
|
||||
for i, value in enumerate(values)]
|
||||
register_obj.register_additional_items(objects)
|
||||
return register_obj
|
||||
collection.add_additional_items(objects)
|
||||
return collection
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
@@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
|
||||
return self._array
|
||||
|
||||
|
||||
class Entities(ObjectRegister):
|
||||
_accepted_objects = EntityRegister
|
||||
class Entities(ObjectCollection):
|
||||
_accepted_objects = EntityCollection
|
||||
|
||||
@property
|
||||
def arrays(self):
|
||||
@@ -352,7 +352,7 @@ class Entities(ObjectRegister):
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return list(self._register.keys())
|
||||
return list(self._collection.keys())
|
||||
|
||||
def __init__(self):
|
||||
super(Entities, self).__init__()
|
||||
@@ -360,21 +360,21 @@ class Entities(ObjectRegister):
|
||||
def iter_individual_entitites(self):
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def register_item(self, other: dict):
|
||||
def add_item(self, other: dict):
|
||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||
"This group of entities has already been registered!"
|
||||
self._register.update(other)
|
||||
"This group of entities has already been added!"
|
||||
self._collection.update(other)
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: Dict):
|
||||
return self.register_item(others)
|
||||
def add_additional_items(self, others: Dict):
|
||||
return self.add_item(others)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||
return found_entities
|
||||
|
||||
|
||||
class Walls(EntityRegister):
|
||||
class Walls(EntityCollection):
|
||||
_accepted_objects = Wall
|
||||
|
||||
def as_array(self):
|
||||
@@ -396,7 +396,7 @@ class Walls(EntityRegister):
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
tiles.add_additional_items(
|
||||
[cls._accepted_objects(pos, tiles)
|
||||
for pos in argwhere_coordinates]
|
||||
)
|
||||
@@ -441,7 +441,7 @@ class Floors(Walls):
|
||||
return {}
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
class Agents(MovingEntityObjectCollection):
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
|
||||
old_agent = self[key]
|
||||
self[key].tile.leave(self[key])
|
||||
agent._name = old_agent.name
|
||||
self._register[agent.name] = agent
|
||||
self._collection[agent.name] = agent
|
||||
|
||||
|
||||
class Doors(EntityRegister):
|
||||
class Doors(EntityCollection):
|
||||
|
||||
def __init__(self, *args, have_area: bool = False, **kwargs):
|
||||
self.have_area = have_area
|
||||
@@ -490,7 +490,7 @@ class Doors(EntityRegister):
|
||||
return super(Doors, self).as_array()
|
||||
|
||||
|
||||
class Actions(ObjectRegister):
|
||||
class Actions(ObjectCollection):
|
||||
_accepted_objects = Action
|
||||
|
||||
@property
|
||||
@@ -507,22 +507,22 @@ class Actions(ObjectRegister):
|
||||
|
||||
# Move this to Baseclass, Env init?
|
||||
if self.allow_square_movement:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.square_move()])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.square_move()])
|
||||
if self.allow_diagonal_movement:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.diagonal_move()])
|
||||
self._movement_actions = self._register.copy()
|
||||
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.diagonal_move()])
|
||||
self._movement_actions = self._collection.copy()
|
||||
if self.can_use_doors:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||
if self.allow_no_op:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||
|
||||
def is_moving_action(self, action: Union[int]):
|
||||
return action in self.movement_actions.values()
|
||||
|
||||
|
||||
class Zones(ObjectRegister):
|
||||
class Zones(ObjectCollection):
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
@@ -551,5 +551,5 @@ class Zones(ObjectRegister):
|
||||
def __getitem__(self, item):
|
||||
return self._zone_slices[item]
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
def add_additional_items(self, other: Union[str, List[str]]):
|
||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
||||
|
Reference in New Issue
Block a user