Adjustments and Documentation

This commit is contained in:
Steffen Illium
2022-04-11 16:15:44 +02:00
parent 3e19970a60
commit 0218f8f4e9
12 changed files with 394 additions and 182 deletions

View File

@@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
np.argwhere(level_array == c.OCCUPIED_CELL),
self._level_shape
)
self._entities.register_additional_items({c.WALLS: walls})
self._entities.add_additional_items({c.WALLS: walls})
# Floor
floor = Floors.from_argwhere_coordinates(
np.argwhere(level_array == c.FREE_CELL),
self._level_shape
)
self._entities.register_additional_items({c.FLOOR: floor})
self._entities.add_additional_items({c.FLOOR: floor})
# NOPOS
self._NO_POS_TILE = Floor(c.NO_POS, None)
@@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
entity_kwargs=dict(context=floor)
)
self._entities.register_additional_items({c.DOORS: doors})
self._entities.add_additional_items({c.DOORS: doors})
# Actions
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
if additional_actions := self.actions_hook:
self._actions.register_additional_items(additional_actions)
self._actions.add_additional_items(additional_actions)
# Agents
agents_to_spawn = self.n_agents-len(self._injected_agents)
@@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
if self._injected_agents:
initialized_injections = list()
for i, injection in enumerate(self._injected_agents):
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
initialized_injections.append(agents[-1])
self._initialized_injections = initialized_injections
self._entities.register_additional_items({c.AGENT: agents})
self._entities.add_additional_items({c.AGENT: agents})
if self.obs_prop.additional_agent_placeholder is not None:
# TODO: Make this accept Lists for multiple placeholders
@@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
fill_value=self.obs_prop.additional_agent_placeholder)
)
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
# Additional Entitites from SubEnvs
if additional_entities := self.entities_hook:
self._entities.register_additional_items(additional_entities)
self._entities.add_additional_items(additional_entities)
if self.obs_prop.show_global_position_info:
global_positions = GlobalPositions(self._level_shape)
# This moved into the GlobalPosition object
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
global_positions.spawn_global_position_objects(self[c.AGENT])
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
# Return
return self._entities
@@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
# Actions
x_diff, y_diff = h.ACTIONMAP[action.identifier]
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
x_new = agent.x + x_diff
y_new = agent.y + y_diff

View File

@@ -72,15 +72,15 @@ class EnvObject(Object):
def encoding(self):
return c.OCCUPIED_CELL
def __init__(self, register, **kwargs):
def __init__(self, collection, **kwargs):
super(EnvObject, self).__init__(**kwargs)
self._register = register
self._collection = collection
def change_register(self, register):
register.register_item(self)
self._register.delete_env_object(self)
self._register = register
return self._register == register
def change_parent_collection(self, other_collection):
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
# With Rendering
@@ -153,7 +153,7 @@ class MoveableEntity(Entity):
curr_tile.leave(self)
self._tile = next_tile
self._last_tile = curr_tile
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
@@ -371,13 +371,13 @@ class Door(Entity):
def _open(self):
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
self._state = c.OPEN_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
self.time_to_close = self.auto_close_interval
def _close(self):
self.connectivity.remove_node(self.pos)
self._state = c.CLOSED_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
def is_linked(self, old_pos, new_pos):
try:

View File

@@ -13,11 +13,11 @@ from environments import helpers as h
from environments.helpers import Constants as c
##########################################################################
# ##################### Base Register Definition ####################### #
# ################## Base Collections Definition ####################### #
##########################################################################
class ObjectRegister:
class ObjectCollection:
_accepted_objects = Object
@property
@@ -25,59 +25,59 @@ class ObjectRegister:
return f'{self.__class__.__name__}'
def __init__(self, *args, **kwargs):
self._register = dict()
self._collection = dict()
def __len__(self):
return len(self._register)
return len(self._collection)
def __iter__(self):
return iter(self.values())
def register_item(self, other: _accepted_objects):
def add_item(self, other: _accepted_objects):
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,'
self._register.update({other.name: other})
self._collection.update({other.name: other})
return self
def register_additional_items(self, others: List[_accepted_objects]):
def add_additional_items(self, others: List[_accepted_objects]):
for other in others:
self.register_item(other)
self.add_item(other)
return self
def keys(self):
return self._register.keys()
return self._collection.keys()
def values(self):
return self._register.values()
return self._collection.values()
def items(self):
return self._register.items()
return self._collection.items()
def _get_index(self, item):
try:
return next(i for i, v in enumerate(self._register.values()) if v == item)
return next(i for i, v in enumerate(self._collection.values()) if v == item)
except StopIteration:
return None
def __getitem__(self, item):
if isinstance(item, (int, np.int64, np.int32)):
if item < 0:
item = len(self._register) - abs(item)
item = len(self._collection) - abs(item)
try:
return next(v for i, v in enumerate(self._register.values()) if i == item)
return next(v for i, v in enumerate(self._collection.values()) if i == item)
except StopIteration:
return None
try:
return self._register[item]
return self._collection[item]
except KeyError:
return None
def __repr__(self):
return f'{self.__class__.__name__}[{self._register}]'
return f'{self.__class__.__name__}[{self._collection}]'
class EnvObjectRegister(ObjectRegister):
class EnvObjectCollection(ObjectCollection):
_accepted_objects = EnvObject
@@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
is_blocking_light: bool = False,
can_collide: bool = False,
can_be_shadowed: bool = True, **kwargs):
super(EnvObjectRegister, self).__init__(*args, **kwargs)
super(EnvObjectCollection, self).__init__(*args, **kwargs)
self._shape = obs_shape
self._array = None
self._individual_slices = individual_slices
@@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
self.can_be_shadowed = can_be_shadowed
self.can_collide = can_collide
def register_item(self, other: EnvObject):
super(EnvObjectRegister, self).register_item(other)
def add_item(self, other: EnvObject):
super(EnvObjectCollection, self).add_item(other)
if self._array is None:
self._array = np.zeros((1, *self._shape))
else:
@@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
if self._individual_slices:
self._array = np.delete(self._array, idx, axis=0)
else:
self.notify_change_to_free(self._register[name])
self.notify_change_to_free(self._collection[name])
# Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions
# in the observation array are result of enumeration. They can overide each other.
# Todo: Find a better solution
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
self._refresh_arrays()
del self._register[name]
del self._collection[name]
def delete_env_object(self, env_object: EnvObject):
del self[env_object.name]
@@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
del self[name]
class EntityRegister(EnvObjectRegister, ABC):
class EntityCollection(EnvObjectCollection, ABC):
_accepted_objects = Entity
@classmethod
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# objects_name = cls._accepted_objects.__name__
register_obj = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
collection = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, collection, str_ident=i,
**entity_kwargs if entity_kwargs is not None else {})
for i, tile in enumerate(tiles)]
register_obj.register_additional_items(entities)
return register_obj
collection.add_additional_items(entities)
return collection
@classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
@@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
return [entity.tile for entity in self]
def __init__(self, level_shape, *args, **kwargs):
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
self._lazy_eval_transforms = []
def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj)
super(EntityRegister, self).__delitem__(name)
super(EntityCollection, self).__delitem__(name)
def as_array(self):
if self._lazy_eval_transforms:
@@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
return None
class BoundEnvObjRegister(EnvObjectRegister, ABC):
class BoundEnvObjCollection(EnvObjectCollection, ABC):
def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
return self._array[self.idx_by_entity(entity)]
class MovingEntityObjectRegister(EntityRegister, ABC):
class MovingEntityObjectCollection(EntityCollection, ABC):
def __init__(self, *args, **kwargs):
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
def notify_change_to_value(self, entity):
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
if entity.last_pos != c.NO_POS:
try:
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
@@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
##########################################################################
# ################# Objects and Entity Registers ####################### #
# ################# Objects and Entity Collection ###################### #
##########################################################################
class GlobalPositions(EnvObjectRegister):
class GlobalPositions(EnvObjectCollection):
_accepted_objects = GlobalPosition
@@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
global_positions = [self._accepted_objects(self._shape, agent, self)
for _, agent in enumerate(agents)]
# noinspection PyTypeChecker
self.register_additional_items(global_positions)
self.add_additional_items(global_positions)
def summarize_states(self, n_steps=None):
return {}
@@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
return None
class PlaceHolders(EnvObjectRegister):
class PlaceHolders(EnvObjectCollection):
_accepted_objects = PlaceHolder
def __init__(self, *args, **kwargs):
@@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
# objects_name = cls._accepted_objects.__name__
if isinstance(values, (str, numbers.Number)):
values = [values]
register_obj = cls(*args, **kwargs)
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
collection = cls(*args, **kwargs)
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
**object_kwargs if object_kwargs is not None else {})
for i, value in enumerate(values)]
register_obj.register_additional_items(objects)
return register_obj
collection.add_additional_items(objects)
return collection
# noinspection DuplicatedCode
def as_array(self):
@@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
return self._array
class Entities(ObjectRegister):
_accepted_objects = EntityRegister
class Entities(ObjectCollection):
_accepted_objects = EntityCollection
@property
def arrays(self):
@@ -352,7 +352,7 @@ class Entities(ObjectRegister):
@property
def names(self):
return list(self._register.keys())
return list(self._collection.keys())
def __init__(self):
super(Entities, self).__init__()
@@ -360,21 +360,21 @@ class Entities(ObjectRegister):
def iter_individual_entitites(self):
return iter((x for sublist in self.values() for x in sublist))
def register_item(self, other: dict):
def add_item(self, other: dict):
assert not any([key for key in other.keys() if key in self.keys()]), \
"This group of entities has already been registered!"
self._register.update(other)
"This group of entities has already been added!"
self._collection.update(other)
return self
def register_additional_items(self, others: Dict):
return self.register_item(others)
def add_additional_items(self, others: Dict):
return self.add_item(others)
def by_pos(self, pos: (int, int)):
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
return found_entities
class Walls(EntityRegister):
class Walls(EntityCollection):
_accepted_objects = Wall
def as_array(self):
@@ -396,7 +396,7 @@ class Walls(EntityRegister):
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
tiles.add_additional_items(
[cls._accepted_objects(pos, tiles)
for pos in argwhere_coordinates]
)
@@ -441,7 +441,7 @@ class Floors(Walls):
return {}
class Agents(MovingEntityObjectRegister):
class Agents(MovingEntityObjectCollection):
_accepted_objects = Agent
def __init__(self, *args, **kwargs):
@@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
old_agent = self[key]
self[key].tile.leave(self[key])
agent._name = old_agent.name
self._register[agent.name] = agent
self._collection[agent.name] = agent
class Doors(EntityRegister):
class Doors(EntityCollection):
def __init__(self, *args, have_area: bool = False, **kwargs):
self.have_area = have_area
@@ -490,7 +490,7 @@ class Doors(EntityRegister):
return super(Doors, self).as_array()
class Actions(ObjectRegister):
class Actions(ObjectCollection):
_accepted_objects = Action
@property
@@ -507,22 +507,22 @@ class Actions(ObjectRegister):
# Move this to Baseclass, Env init?
if self.allow_square_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()])
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()])
if self.allow_diagonal_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._register.copy()
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._collection.copy()
if self.can_use_doors:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
if self.allow_no_op:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values()
class Zones(ObjectRegister):
class Zones(ObjectCollection):
@property
def accounting_zones(self):
@@ -551,5 +551,5 @@ class Zones(ObjectRegister):
def __getitem__(self, item):
return self._zone_slices[item]
def register_additional_items(self, other: Union[str, List[str]]):
def add_additional_items(self, other: Union[str, List[str]]):
raise AttributeError('You are not allowed to add additional Zones in runtime.')