diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index f00e049..cd397e4 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -156,14 +156,14 @@ class BaseFactory(gym.Env): np.argwhere(level_array == c.OCCUPIED_CELL), self._level_shape ) - self._entities.register_additional_items({c.WALLS: walls}) + self._entities.add_additional_items({c.WALLS: walls}) # Floor floor = Floors.from_argwhere_coordinates( np.argwhere(level_array == c.FREE_CELL), self._level_shape ) - self._entities.register_additional_items({c.FLOOR: floor}) + self._entities.add_additional_items({c.FLOOR: floor}) # NOPOS self._NO_POS_TILE = Floor(c.NO_POS, None) @@ -177,12 +177,12 @@ class BaseFactory(gym.Env): doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area, entity_kwargs=dict(context=floor) ) - self._entities.register_additional_items({c.DOORS: doors}) + self._entities.add_additional_items({c.DOORS: doors}) # Actions self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors) if additional_actions := self.actions_hook: - self._actions.register_additional_items(additional_actions) + self._actions.add_additional_items(additional_actions) # Agents agents_to_spawn = self.n_agents-len(self._injected_agents) @@ -196,10 +196,10 @@ class BaseFactory(gym.Env): if self._injected_agents: initialized_injections = list() for i, injection in enumerate(self._injected_agents): - agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False)) + agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False)) initialized_injections.append(agents[-1]) self._initialized_injections = initialized_injections - self._entities.register_additional_items({c.AGENT: agents}) + self._entities.add_additional_items({c.AGENT: agents}) if self.obs_prop.additional_agent_placeholder is not None: # TODO: Make this accept Lists for 
multiple placeholders @@ -210,18 +210,18 @@ class BaseFactory(gym.Env): fill_value=self.obs_prop.additional_agent_placeholder) ) - self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder}) + self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder}) # Additional Entitites from SubEnvs if additional_entities := self.entities_hook: - self._entities.register_additional_items(additional_entities) + self._entities.add_additional_items(additional_entities) if self.obs_prop.show_global_position_info: global_positions = GlobalPositions(self._level_shape) # This moved into the GlobalPosition object # obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) global_positions.spawn_global_position_objects(self[c.AGENT]) - self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) + self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions}) # Return return self._entities @@ -535,7 +535,7 @@ class BaseFactory(gym.Env): def _check_agent_move(self, agent, action: Action) -> (Floor, bool): # Actions - x_diff, y_diff = h.ACTIONMAP[action.identifier] + x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier) x_new = agent.x + x_diff y_new = agent.y + y_diff diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index d8837fd..8980b39 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -72,15 +72,15 @@ class EnvObject(Object): def encoding(self): return c.OCCUPIED_CELL - def __init__(self, register, **kwargs): + def __init__(self, collection, **kwargs): super(EnvObject, self).__init__(**kwargs) - self._register = register + self._collection = collection - def change_register(self, register): - register.register_item(self) - self._register.delete_env_object(self) - self._register = register - return self._register == register + def change_parent_collection(self, other_collection): + 
other_collection.add_item(self) + self._collection.delete_env_object(self) + self._collection = other_collection + return self._collection == other_collection # With Rendering @@ -153,7 +153,7 @@ class MoveableEntity(Entity): curr_tile.leave(self) self._tile = next_tile self._last_tile = curr_tile - self._register.notify_change_to_value(self) + self._collection.notify_change_to_value(self) return c.VALID else: return c.NOT_VALID @@ -371,13 +371,13 @@ class Door(Entity): def _open(self): self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))]) self._state = c.OPEN_DOOR - self._register.notify_change_to_value(self) + self._collection.notify_change_to_value(self) self.time_to_close = self.auto_close_interval def _close(self): self.connectivity.remove_node(self.pos) self._state = c.CLOSED_DOOR - self._register.notify_change_to_value(self) + self._collection.notify_change_to_value(self) def is_linked(self, old_pos, new_pos): try: diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index f6078dd..66ca818 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -13,11 +13,11 @@ from environments import helpers as h from environments.helpers import Constants as c ########################################################################## -# ##################### Base Register Definition ####################### # +# ################## Base Collections Definition ####################### # ########################################################################## -class ObjectRegister: +class ObjectCollection: _accepted_objects = Object @property @@ -25,59 +25,59 @@ class ObjectRegister: return f'{self.__class__.__name__}' def __init__(self, *args, **kwargs): - self._register = dict() + self._collection = dict() def __len__(self): - return len(self._register) + return len(self._collection) def __iter__(self): return iter(self.values()) - def 
register_item(self, other: _accepted_objects): + def add_item(self, other: _accepted_objects): assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \ f'{self._accepted_objects}, ' \ f'but were {other.__class__}.,' - self._register.update({other.name: other}) + self._collection.update({other.name: other}) return self - def register_additional_items(self, others: List[_accepted_objects]): + def add_additional_items(self, others: List[_accepted_objects]): for other in others: - self.register_item(other) + self.add_item(other) return self def keys(self): - return self._register.keys() + return self._collection.keys() def values(self): - return self._register.values() + return self._collection.values() def items(self): - return self._register.items() + return self._collection.items() def _get_index(self, item): try: - return next(i for i, v in enumerate(self._register.values()) if v == item) + return next(i for i, v in enumerate(self._collection.values()) if v == item) except StopIteration: return None def __getitem__(self, item): if isinstance(item, (int, np.int64, np.int32)): if item < 0: - item = len(self._register) - abs(item) + item = len(self._collection) - abs(item) try: - return next(v for i, v in enumerate(self._register.values()) if i == item) + return next(v for i, v in enumerate(self._collection.values()) if i == item) except StopIteration: return None try: - return self._register[item] + return self._collection[item] except KeyError: return None def __repr__(self): - return f'{self.__class__.__name__}[{self._register}]' + return f'{self.__class__.__name__}[{self._collection}]' -class EnvObjectRegister(ObjectRegister): +class EnvObjectCollection(ObjectCollection): _accepted_objects = EnvObject @@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister): is_blocking_light: bool = False, can_collide: bool = False, can_be_shadowed: bool = True, **kwargs): - super(EnvObjectRegister, self).__init__(*args, **kwargs) + 
super(EnvObjectCollection, self).__init__(*args, **kwargs) self._shape = obs_shape self._array = None self._individual_slices = individual_slices @@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister): self.can_be_shadowed = can_be_shadowed self.can_collide = can_collide - def register_item(self, other: EnvObject): - super(EnvObjectRegister, self).register_item(other) + def add_item(self, other: EnvObject): + super(EnvObjectCollection, self).add_item(other) if self._array is None: self._array = np.zeros((1, *self._shape)) else: @@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister): if self._individual_slices: self._array = np.delete(self._array, idx, axis=0) else: - self.notify_change_to_free(self._register[name]) + self.notify_change_to_free(self._collection[name]) # Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions # in the observation array are result of enumeration. They can overide each other. # Todo: Find a better solution - if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister): + if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection): self._refresh_arrays() - del self._register[name] + del self._collection[name] def delete_env_object(self, env_object: EnvObject): del self[env_object.name] @@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister): del self[name] -class EntityRegister(EnvObjectRegister, ABC): +class EntityCollection(EnvObjectCollection, ABC): _accepted_objects = Entity @classmethod def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs): # objects_name = cls._accepted_objects.__name__ - register_obj = cls(*args, **kwargs) - entities = [cls._accepted_objects(tile, register_obj, str_ident=i, + collection = cls(*args, **kwargs) + entities = [cls._accepted_objects(tile, collection, str_ident=i, **entity_kwargs if entity_kwargs is not None else {}) for i, tile in 
enumerate(tiles)] - register_obj.register_additional_items(entities) - return register_obj + collection.add_additional_items(entities) + return collection @classmethod def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ): @@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC): return [entity.tile for entity in self] def __init__(self, level_shape, *args, **kwargs): - super(EntityRegister, self).__init__(level_shape, *args, **kwargs) + super(EntityCollection, self).__init__(level_shape, *args, **kwargs) self._lazy_eval_transforms = [] def __delitem__(self, name): idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) obj.tile.leave(obj) - super(EntityRegister, self).__delitem__(name) + super(EntityCollection, self).__delitem__(name) def as_array(self): if self._lazy_eval_transforms: @@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC): return None -class BoundEnvObjRegister(EnvObjectRegister, ABC): +class BoundEnvObjCollection(EnvObjectCollection, ABC): def __init__(self, entity_to_be_bound, *args, **kwargs): super().__init__(*args, **kwargs) @@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC): return self._array[self.idx_by_entity(entity)] -class MovingEntityObjectRegister(EntityRegister, ABC): +class MovingEntityObjectCollection(EntityCollection, ABC): def __init__(self, *args, **kwargs): - super(MovingEntityObjectRegister, self).__init__(*args, **kwargs) + super(MovingEntityObjectCollection, self).__init__(*args, **kwargs) def notify_change_to_value(self, entity): - super(MovingEntityObjectRegister, self).notify_change_to_value(entity) + super(MovingEntityObjectCollection, self).notify_change_to_value(entity) if entity.last_pos != c.NO_POS: try: self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL) @@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC): 
########################################################################## -# ################# Objects and Entity Registers ####################### # +# ################# Objects and Entity Collection ###################### # ########################################################################## -class GlobalPositions(EnvObjectRegister): +class GlobalPositions(EnvObjectCollection): _accepted_objects = GlobalPosition @@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister): global_positions = [self._accepted_objects(self._shape, agent, self) for _, agent in enumerate(agents)] # noinspection PyTypeChecker - self.register_additional_items(global_positions) + self.add_additional_items(global_positions) def summarize_states(self, n_steps=None): return {} @@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister): return None -class PlaceHolders(EnvObjectRegister): +class PlaceHolders(EnvObjectCollection): _accepted_objects = PlaceHolder def __init__(self, *args, **kwargs): @@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister): # objects_name = cls._accepted_objects.__name__ if isinstance(values, (str, numbers.Number)): values = [values] - register_obj = cls(*args, **kwargs) - objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value, + collection = cls(*args, **kwargs) + objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value, **object_kwargs if object_kwargs is not None else {}) for i, value in enumerate(values)] - register_obj.register_additional_items(objects) - return register_obj + collection.add_additional_items(objects) + return collection # noinspection DuplicatedCode def as_array(self): @@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister): return self._array -class Entities(ObjectRegister): - _accepted_objects = EntityRegister +class Entities(ObjectCollection): + _accepted_objects = EntityCollection @property def arrays(self): @@ -352,7 +352,7 @@ class Entities(ObjectRegister): @property def 
names(self): - return list(self._register.keys()) + return list(self._collection.keys()) def __init__(self): super(Entities, self).__init__() @@ -360,21 +360,21 @@ class Entities(ObjectRegister): def iter_individual_entitites(self): return iter((x for sublist in self.values() for x in sublist)) - def register_item(self, other: dict): + def add_item(self, other: dict): assert not any([key for key in other.keys() if key in self.keys()]), \ - "This group of entities has already been registered!" - self._register.update(other) + "This group of entities has already been added!" + self._collection.update(other) return self - def register_additional_items(self, others: Dict): - return self.register_item(others) + def add_additional_items(self, others: Dict): + return self.add_item(others) def by_pos(self, pos: (int, int)): found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None] return found_entities -class Walls(EntityRegister): +class Walls(EntityCollection): _accepted_objects = Wall def as_array(self): @@ -396,7 +396,7 @@ class Walls(EntityRegister): def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs): tiles = cls(*args, **kwargs) # noinspection PyTypeChecker - tiles.register_additional_items( + tiles.add_additional_items( [cls._accepted_objects(pos, tiles) for pos in argwhere_coordinates] ) @@ -441,7 +441,7 @@ class Floors(Walls): return {} -class Agents(MovingEntityObjectRegister): +class Agents(MovingEntityObjectCollection): _accepted_objects = Agent def __init__(self, *args, **kwargs): @@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister): old_agent = self[key] self[key].tile.leave(self[key]) agent._name = old_agent.name - self._register[agent.name] = agent + self._collection[agent.name] = agent -class Doors(EntityRegister): +class Doors(EntityCollection): def __init__(self, *args, have_area: bool = False, **kwargs): self.have_area = have_area @@ -490,7 +490,7 @@ class 
Doors(EntityRegister): return super(Doors, self).as_array() -class Actions(ObjectRegister): +class Actions(ObjectCollection): _accepted_objects = Action @property @@ -507,22 +507,22 @@ class Actions(ObjectRegister): # Move this to Baseclass, Env init? if self.allow_square_movement: - self.register_additional_items([self._accepted_objects(str_ident=direction) - for direction in h.EnvActions.square_move()]) + self.add_additional_items([self._accepted_objects(str_ident=direction) + for direction in h.EnvActions.square_move()]) if self.allow_diagonal_movement: - self.register_additional_items([self._accepted_objects(str_ident=direction) - for direction in h.EnvActions.diagonal_move()]) - self._movement_actions = self._register.copy() + self.add_additional_items([self._accepted_objects(str_ident=direction) + for direction in h.EnvActions.diagonal_move()]) + self._movement_actions = self._collection.copy() if self.can_use_doors: - self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) + self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) if self.allow_no_op: - self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) + self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) def is_moving_action(self, action: Union[int]): return action in self.movement_actions.values() -class Zones(ObjectRegister): +class Zones(ObjectCollection): @property def accounting_zones(self): @@ -551,5 +551,5 @@ class Zones(ObjectRegister): def __getitem__(self, item): return self._zone_slices[item] - def register_additional_items(self, other: Union[str, List[str]]): + def add_additional_items(self, other: Union[str, List[str]]): raise AttributeError('You are not allowed to add additional Zones in runtime.') diff --git a/environments/factory/factory_battery.py b/environments/factory/factory_battery.py index c09cb10..0ab132c 100644 --- a/environments/factory/factory_battery.py 
+++ b/environments/factory/factory_battery.py @@ -4,7 +4,7 @@ import numpy as np from environments.factory.base.base_factory import BaseFactory from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin -from environments.factory.base.registers import EntityRegister, EnvObjectRegister +from environments.factory.base.registers import EntityCollection, EnvObjectCollection from environments.factory.base.renderer import RenderEntity from environments.helpers import Constants as BaseConstants from environments.helpers import EnvActions as BaseActions @@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject): if self.charge_level != 0: # noinspection PyTypeChecker self.charge_level = max(0, amount + self.charge_level) - self._register.notify_change_to_value(self) + self._collection.notify_change_to_value(self) return c.VALID else: return c.NOT_VALID @@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject): return attr_dict -class BatteriesRegister(EnvObjectRegister): +class BatteriesRegister(EnvObjectCollection): _accepted_objects = Battery @@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister): def spawn_batteries(self, agents, initial_charge_level): batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] - self.register_additional_items(batteries) + self.add_additional_items(batteries) def summarize_states(self, n_steps=None): # as dict with additional nesting @@ -140,7 +140,7 @@ class ChargePod(Entity): return summary -class ChargePods(EntityRegister): +class ChargePods(EntityCollection): _accepted_objects = ChargePod diff --git a/environments/factory/factory_dest.py b/environments/factory/factory_dest.py index a6bb6d7..b5c1477 100644 --- a/environments/factory/factory_dest.py +++ b/environments/factory/factory_dest.py @@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory from environments.helpers import Constants as BaseConstants from 
environments.helpers import EnvActions as BaseActions from environments.factory.base.objects import Agent, Entity, Action -from environments.factory.base.registers import Entities, EntityRegister +from environments.factory.base.registers import Entities, EntityCollection from environments.factory.base.renderer import RenderEntity @@ -73,7 +73,7 @@ class Destination(Entity): return state_summary -class Destinations(EntityRegister): +class Destinations(EntityCollection): _accepted_objects = Destination @@ -208,13 +208,13 @@ class DestFactory(BaseFactory): n_dest_to_spawn = len(destinations_to_spawn) if self.dest_prop.spawn_mode != DestModeOptions.GROUPED: destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] - self[c.DEST].register_additional_items(destinations) + self[c.DEST].add_additional_items(destinations) for dest in destinations_to_spawn: del self._dest_spawn_timer[dest] self.print(f'{n_dest_to_spawn} new destinations have been spawned') elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests: destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] - self[c.DEST].register_additional_items(destinations) + self[c.DEST].add_additional_items(destinations) for dest in destinations_to_spawn: del self._dest_spawn_timer[dest] self.print(f'{n_dest_to_spawn} new destinations have been spawned') @@ -231,7 +231,7 @@ class DestFactory(BaseFactory): self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) for dest in list(self[c.DEST].values()): if dest.is_considered_reached: - dest.change_register(self[c.DEST]) + dest.change_parent_collection(self[c.DEST]) self._dest_spawn_timer[dest.name] = 0 self.print(f'{dest.name} is reached now, removing...') else: diff --git a/environments/factory/factory_dirt.py b/environments/factory/factory_dirt.py index bdf1834..3ba9779 100644 --- 
a/environments/factory/factory_dirt.py +++ b/environments/factory/factory_dirt.py @@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions from environments.factory.base.base_factory import BaseFactory from environments.factory.base.objects import Agent, Action, Entity, Floor -from environments.factory.base.registers import Entities, EntityRegister +from environments.factory.base.registers import Entities, EntityCollection from environments.factory.base.renderer import RenderEntity from environments.utility_classes import ObservationProperties @@ -61,7 +61,7 @@ class Dirt(Entity): def set_new_amount(self, amount): self._amount = amount - self._register.notify_change_to_value(self) + self._collection.notify_change_to_value(self) def summarize_state(self, **kwargs): state_dict = super().summarize_state(**kwargs) @@ -69,7 +69,7 @@ class Dirt(Entity): return state_dict -class DirtRegister(EntityRegister): +class DirtRegister(EntityCollection): _accepted_objects = Dirt @@ -93,7 +93,7 @@ class DirtRegister(EntityRegister): dirt = self.by_pos(tile.pos) if dirt is None: dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount) - self.register_item(dirt) + self.add_item(dirt) else: new_value = dirt.amount + self.dirt_properties.max_spawn_amount dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount)) diff --git a/environments/factory/factory_dirt_stationary_machines.py b/environments/factory/factory_dirt_stationary_machines.py index e0bf477..377666e 100644 --- a/environments/factory/factory_dirt_stationary_machines.py +++ b/environments/factory/factory_dirt_stationary_machines.py @@ -5,10 +5,10 @@ import numpy as np from environments.factory.base.objects import Agent, Entity, Action from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory from environments.factory.base.objects import Floor -from environments.factory.base.registers import Floors, Entities, EntityRegister +from 
environments.factory.base.registers import Floors, Entities, EntityCollection -class Machines(EntityRegister): +class Machines(EntityCollection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/environments/factory/factory_item.py b/environments/factory/factory_item.py index 5de3fe4..8d91e7f 100644 --- a/environments/factory/factory_item.py +++ b/environments/factory/factory_item.py @@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants from environments.helpers import EnvActions as BaseActions from environments import helpers as h from environments.factory.base.objects import Agent, Entity, Action, Floor -from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister +from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection from environments.factory.base.renderer import RenderEntity @@ -53,17 +53,17 @@ class Item(Entity): self._auto_despawn = auto_despawn def set_tile_to(self, no_pos_tile): - assert self._register.__class__.__name__ != ItemRegister.__class__ + assert self._collection.__class__.__name__ != ItemRegister.__class__ self._tile = no_pos_tile -class ItemRegister(EntityRegister): +class ItemRegister(EntityCollection): _accepted_objects = Item def spawn_items(self, tiles: List[Floor]): items = [Item(tile, self) for tile in tiles] - self.register_additional_items(items) + self.add_additional_items(items) def despawn_items(self, items: List[Item]): items = [items] if isinstance(items, Item) else items @@ -71,7 +71,7 @@ class ItemRegister(EntityRegister): del self[item] -class Inventory(BoundEnvObjRegister): +class Inventory(BoundEnvObjCollection): @property def name(self): @@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister): return item_to_pop -class Inventories(ObjectRegister): +class Inventories(ObjectCollection): _accepted_objects = Inventory is_blocking_light = False @@ -114,7 +114,7 @@ 
class Inventories(ObjectRegister): def spawn_inventories(self, agents, capacity): inventories = [self._accepted_objects(agent, capacity, self._obs_shape) for _, agent in enumerate(agents)] - self.register_additional_items(inventories) + self.add_additional_items(inventories) def idx_by_entity(self, entity): try: @@ -161,7 +161,7 @@ class DropOffLocation(Entity): return super().summarize_state(n_steps=n_steps) -class DropOffLocations(EntityRegister): +class DropOffLocations(EntityCollection): _accepted_objects = DropOffLocation @@ -250,7 +250,7 @@ class ItemFactory(BaseFactory): reason=a.ITEM_ACTION, info=info_dict) return valid, reward elif item := self[c.ITEM].by_pos(agent.pos): - item.change_register(inventory) + item.change_parent_collection(inventory) item.set_tile_to(self._NO_POS_TILE) self.print(f'{agent.name} just picked up an item at {agent.pos}') info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1} diff --git a/environments/helpers.py b/environments/helpers.py index 954ff78..526b612 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -7,47 +7,76 @@ import numpy as np from numpy.typing import ArrayLike from stable_baselines3 import PPO, DQN, A2C -MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) -LEVELS_DIR = 'levels' -STEPS_START = 1 - -TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] -IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', - 'dirty_tile_count', 'terminal_observation', 'episode'] +""" +This file is used for: + 1. string based definition + Use a class like `Constants`, to define attributes, which then reveal strings. + These can be used for naming convention along the environments as well as keys for mappings such as dicts etc. + When defining new envs, use class inheritance. + + 2. utility function definition + There are static utility functions which are not bound to a specific environment. + In this file they are defined to be used across the entire package. 
+""" + + +MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments + + +LEVELS_DIR = 'levels' # for use in studies and experiments +STEPS_START = 1 # Defines where the step count starts, i.e. which is the first step + +# Not used anymore? Clean! +# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] +IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files + 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation', + 'episode'] -# Constants class Constants: - WALL = '#' - WALLS = 'Walls' - FLOOR = 'Floor' - DOOR = 'D' - DANGER_ZONE = 'x' - LEVEL = 'Level' - AGENT = 'Agent' - AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' - GLOBAL_POSITION = 'GLOBAL_POSITION' - FREE_CELL = 0 - OCCUPIED_CELL = 1 - SHADOWED_CELL = -1 - ACCESS_DOOR_CELL = 1/3 - OPEN_DOOR_CELL = 2/3 - CLOSED_DOOR_CELL = 3/3 - NO_POS = (-9999, -9999) - DOORS = 'Doors' - CLOSED_DOOR = 'closed' - OPEN_DOOR = 'open' - ACCESS_DOOR = 'access' + """ + String based mapping. Use these to handle keys or define values, which can then be used globally. + Please use class inheritance when defining new environments. + """ - ACTION = 'action' - COLLISION = 'collision' - VALID = True - NOT_VALID = False + WALL = '#' # Wall tile identifier for resolving the string based map files. + DOOR = 'D' # Door identifier for resolving the string based map files. + DANGER_ZONE = 'x' # Danger Zone tile identifier for resolving the string based map files. + + WALLS = 'Walls' # Identifier of Wall-objects and sets (collections). + FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections). + DOORS = 'Doors' # Identifier of Door-objects and sets (collections). + LEVEL = 'Level' # Identifier of Level-objects and sets (collections). + AGENT = 'Agent' # Identifier of Agent-objects and sets (collections). + AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
+ GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice + + FREE_CELL = 0 # Free-Cell value used in observation + OCCUPIED_CELL = 1 # Occupied-Cell value used in observation + SHADOWED_CELL = -1 # Shadowed-Cell value used in observation + ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation + OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation + CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation + + NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid) + + CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state + OPEN_DOOR = 'open' # Identifier to compare door-is-open state + # ACCESS_DOOR = 'access' # Identifier to compare access positions + + ACTION = 'action' # Identifier of Action-objects and sets (collections). + COLLISION = 'collision' # Identifier to use in the context of collisions. + VALID = True # Identifier to rename boolean values in the context of actions. + NOT_VALID = False # Identifier to rename boolean values in the context of actions. class EnvActions: + """ + String based mapping. Use these to identify actions; they can be used globally. + Please use class inheritance when defining new environments with new actions. + """ # Movements NORTH = 'north' EAST = 'east' @@ -63,24 +92,77 @@ class EnvActions: NOOP = 'no_op' USE_DOOR = 'use_door' + _ACTIONMAP = defaultdict(lambda: (0, 0), + {NORTH: (-1, 0), NORTHEAST: (-1, 1), + EAST: (0, 1), SOUTHEAST: (1, 1), + SOUTH: (1, 0), SOUTHWEST: (1, -1), + WEST: (0, -1), NORTHWEST: (-1, -1) + } + ) + @classmethod - def is_move(cls, other): - return any([other == direction for direction in cls.movement_actions()]) + def is_move(cls, action): + """ + Classmethod; checks if given action is a movement action or not. Depending on the env. configuration, + Movement actions are either `manhattan` (square) style movements (up, down, left, right) and/or diagonal.
+ + :param action: Action to be checked + :type action: str + :return: Whether the given action is a movement action. + :rtype: bool + """ + return any([action == direction for direction in cls.movement_actions()]) @classmethod def square_move(cls): + """ + Classmethod; return a list of movement actions that are considered square or `manhattan` style movements. + + :return: A list of movement actions. + :rtype: list(str) + """ return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] @classmethod def diagonal_move(cls): + """ + Classmethod; return a list of movement actions that are considered diagonal movements. + + :return: A list of movement actions. + :rtype: list(str) + """ return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] @classmethod def movement_actions(cls): + """ + Classmethod; return a list of all available movement actions. + Please note that this is independent of the env. properties + + :return: A list of movement actions. + :rtype: list(str) + """ return list(itertools.chain(cls.square_move(), cls.diagonal_move())) + @classmethod + def resolve_movement_action_to_coords(cls, action): + """ + Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for. + How does the current entity coordinate change if it performs the given action? + Please note that this is independent of the env. properties + + :return: Delta coordinates. + :rtype: tuple(int, int) + """ + return cls._ACTIONMAP[action] + class RewardsBase(NamedTuple): + """ + Value based mapping. Use these to define reward values for specific conditions (i.e. the action + in a given context), can be used globally. + Please use class inheritance when defining new environments with new rewards.
+ """ MOVEMENTS_VALID: float = -0.001 MOVEMENTS_FAIL: float = -0.05 NOOP: float = -0.01 @@ -89,23 +171,31 @@ class RewardsBase(NamedTuple): COLLISION: float = -0.5 -m = EnvActions -c = Constants - -ACTIONMAP = defaultdict(lambda: (0, 0), - {m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1), - m.EAST: (0, 1), m.SOUTHEAST: (1, 1), - m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1), - m.WEST: (0, -1), m.NORTHWEST: (-1, -1) - } - ) - - class ObservationTranslator: def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict], - *per_agent_named_obs_space: Dict[str, dict], + *per_agent_named_obs_spaces: Dict[str, dict], placeholder_fill_value: Union[int, str] = 'N'): + """ + This is a helper class, which converts agents observations from joined environments. + For example, agents trained in different environments may expect different observations. + This class translates from larger observations spaces to smaller. + A string identifier based approach is used. + Currently, it is not possible to mix different obs shapes. + + :param obs_shape_2d: The shape of the observation the agents expect. + :type obs_shape_2d: tuple(int, int) + + :param this_named_observation_space: `Named observation space` of the joined environment. + :type this_named_observation_space: Dict[str, dict] + + :param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded. + type per_agent_named_obs_spaces: Dict[str, dict] + + :param placeholder_fill_value: Currently not fully implemented!!! 
+ :type placeholder_fill_value: Union[int, str] + """ + assert len(obs_shape_2d) == 2 self.obs_shape = obs_shape_2d if isinstance(placeholder_fill_value, str): @@ -119,7 +209,7 @@ class ObservationTranslator: self.random_fill = None self._this_named_obs_space = this_named_observation_space - self._per_agent_named_obs_space = list(per_agent_named_obs_space) + self._per_agent_named_obs_space = list(per_agent_named_obs_spaces) def translate_observation(self, agent_idx: int, obs: np.ndarray): target_obs_space = self._per_agent_named_obs_space[agent_idx] @@ -137,6 +227,19 @@ class ObservationTranslator: class ActionTranslator: def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]): + """ + This is a helper class, which converts agents' action spaces to a joined environment's action space. + For example, agents trained in different environments may have different action spaces. + This class translates from smaller individual agent action spaces to larger joined spaces. + A string identifier based approach is used. + + :param target_named_action_space: Joined `Named action space` for the current environment. + :type target_named_action_space: Dict[str, int] + + :param per_agent_named_action_space: `Named action space` one for each agent. Overloaded. + :type per_agent_named_action_space: Dict[str, int] + """ + self._target_named_action_space = target_named_action_space self._per_agent_named_action_space = list(per_agent_named_action_space) self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space] @@ -155,6 +258,16 @@ class ActionTranslator: # Utility functions def parse_level(path): + """ + Given the path to a string based `level` or `map` representation, this function reads the content. + Strips whitespace from each row, checks for equal length of each row and returns a list of lists. + + :param path: Path to the `level` or `map` file on the hard drive.
+ :type path: pathlib.Path + + :return: The read string representation of the `level` or `map` + :rtype: List[List[str]] + """ with path.open('r') as lvl: level = list(map(lambda x: list(x.strip()), lvl.readlines())) if len(set([len(line) for line in level])) > 1: @@ -162,29 +275,56 @@ def parse_level(path): return level -def one_hot_level(level, wall_char: str = c.WALL): +def one_hot_level(level, wall_char: str = Constants.WALL): + """ + Given a string based level representation (list of lists, see function `parse_level`), this function creates a + binary numpy array or `grid`. Grid values that equal `wall_char` are set to `Constants.OCCUPIED_CELL`. + Can be changed to filter for any symbol. + + :param level: String based level representation (list of lists, see function `parse_level`). + :type level: List[List[str]] + + :return: Binary numpy array + :rtype: np.typing._array_like.ArrayLike + """ + grid = np.array(level) binary_grid = np.zeros(grid.shape, dtype=np.int8) - binary_grid[grid == wall_char] = c.OCCUPIED_CELL + binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL return binary_grid def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]): + """ + Given a slice (2-D Arraylike object), check whether a position lies inside its bounds and is not occupied. + + :param slice_to_check_against: The slice to check for accessibility + :type slice_to_check_against: np.typing._array_like.ArrayLike + + :param position_to_check: Position in slice that should be checked. Can be outside of slice boundaries. + :type position_to_check: tuple(int, int) + + :return: Whether a position can be moved to.
+ :rtype: bool + """ x_pos, y_pos = position_to_check # Check if agent colides with grid boundrys valid = not ( x_pos < 0 or y_pos < 0 or x_pos >= slice_to_check_against.shape[0] - or y_pos >= slice_to_check_against.shape[0] + or y_pos >= slice_to_check_against.shape[1] ) # Check for collision with level walls valid = valid and not slice_to_check_against[x_pos, y_pos] - return c.VALID if valid else c.NOT_VALID + return Constants.VALID if valid else Constants.NOT_VALID def asset_str(agent): + """ + FIXME @ romue + """ # What does this abonimation do? # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): # print('error') @@ -192,33 +332,50 @@ def asset_str(agent): action = step_result['action_name'] valid = step_result['action_valid'] col_names = [x.name for x in step_result['collisions']] - if any(c.AGENT in name for name in col_names): + if any(Constants.AGENT in name for name in col_names): return 'agent_collision', 'blank' - elif not valid or c.LEVEL in col_names or c.AGENT in col_names: - return c.AGENT, 'invalid' + elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names: + return Constants.AGENT, 'invalid' elif valid and not EnvActions.is_move(action): - return c.AGENT, 'valid' + return Constants.AGENT, 'valid' elif valid and EnvActions.is_move(action): - return c.AGENT, 'move' + return Constants.AGENT, 'move' else: - return c.AGENT, 'idle' + return Constants.AGENT, 'idle' else: - return c.AGENT, 'idle' + return Constants.AGENT, 'idle' def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): + """ + Given a set of coordinates, this function constructs a non-directed graph, by connecting adjacent points. + There are three combinations of settings: + Allow all neighbors: Distance(a, b) <= sqrt(2) + Allow only manhattan: Distance(a, b) == 1 + Allow only euclidean: Distance(a, b) == sqrt(2) + + + :param coordiniates_or_tiles: A set of coordinates.
+ :type coordiniates_or_tiles: Tiles + :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors + :type allow_euclidean_connections: bool + :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors + :type allow_manhattan_connections: bool + + :return: A graph with nodes that are connected as specified by the parameters. + :rtype: nx.Graph + """ assert allow_euclidean_connections or allow_manhattan_connections if hasattr(coordiniates_or_tiles, 'positions'): coordiniates_or_tiles = coordiniates_or_tiles.positions possible_connections = itertools.combinations(coordiniates_or_tiles, 2) graph = nx.Graph() for a, b in possible_connections: - diff = abs(np.subtract(a, b)) - if not max(diff) > 1: - if allow_manhattan_connections and allow_euclidean_connections: - graph.add_edge(a, b) - elif not allow_manhattan_connections and allow_euclidean_connections and all(diff): - graph.add_edge(a, b) - elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff): - graph.add_edge(a, b) + diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) + if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): + graph.add_edge(a, b) + elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): + graph.add_edge(a, b) + elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: + graph.add_edge(a, b) return graph diff --git a/environments/utility_classes.py b/environments/utility_classes.py index 9943a61..c91d42b 100644 --- a/environments/utility_classes.py +++ b/environments/utility_classes.py @@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack class AgentRenderOptions(object): + """ + Class that specifies the available options for the way agents are represented in the env observation.
+ + SEPERATE: + Each agent is represented in a separate slice as Constants.OCCUPIED_CELL value (one hot) + + COMBINED: + For all agents, the value of Constants.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE)) + + LEVEL: + The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall) + + NOT: + The position of individual agents cannot be read from the observation. + """ + SEPERATE = 'seperate' COMBINED = 'combined' LEVEL = 'lvl' @@ -11,24 +27,61 @@ class AgentRenderOptions(object): class MovementProperties(NamedTuple): + """ + Property holder; for setting multiple related parameters through a single parameter. Comes with default values. + """ + + """Allow the manhattan style movement on a grid (move to cells that are connected by square edges).""" allow_square_movement: bool = True + + """Allow diagonal movement on the grid (move to cells that are connected by square corners).""" allow_diagonal_movement: bool = False + + """Allow the agent to just do nothing; not move (NO-OP).""" allow_no_op: bool = False class ObservationProperties(NamedTuple): - # Todo: Add Description + """ + Property holder; for setting multiple related parameters through a single parameter. Comes with default values. + """ + + """How to represent agents in the observation space. This may also alter the obs-shape.""" render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE + + """Observations are built per agent; whether the current agent should be represented in its own observation.""" omit_agent_self: bool = True + + """There might be cases where you want to modify the agent's obs-space, so that it can be used with additional obs. + The additional slice can be filled with any number""" additional_agent_placeholder: Union[None, str, int] = None + + """Whether to cast shadows (make floor tiles and items hidden).""" cast_shadows: bool = True + + """Frame Stacking is a method to give some temporal information to the agents.
+ This parameter controls how many "old-frames" are stacked.""" frames_to_stack: int = 0 - pomdp_r: int = 0 + + """Specifies the radius (_r) of the agent's field of view. Please note, that the agent's own grid cell is not taken + into account. This means, that the resulting field of view diameter = `pomdp_r * 2 + 1`. + A 'pomdp_r' of 0 always returns the full env == no partial observability.""" + pomdp_r: int = 2 + + """Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be + operated from their surrounding area. So the agent can more easily get a notion of where to choose the door option. + However, this is not necessary at all. + """ indicate_door_area: bool = False + + """Whether to add the agent's normalized global position as float values (2,1) to a separate information slice. + More optional information is to come. + """ show_global_position_info: bool = False class MarlFrameStack(gym.ObservationWrapper): + """todo @romue404""" def __init__(self, env): super().__init__(env) diff --git a/studies/e_1.py b/studies/e_1.py index ffdfe82..ee1ae28 100644 --- a/studies/e_1.py +++ b/studies/e_1.py @@ -215,7 +215,7 @@ if __name__ == '__main__': clean_amount=0.34, max_spawn_amount=0.1, max_global_amount=20, max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, - dirt_smear_amount=0.0, agent_can_interact=True) + dirt_smear_amount=0.0) item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2, max_agent_inventory_capacity=15) @@ -349,6 +349,7 @@ if __name__ == '__main__': # Env Init & Model kwargs definition if model_cls.__name__ in ["PPO", "A2C"]: # env_factory = env_class(**env_kwargs) + env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) for _ in range(6)], start_method="spawn") model_kwargs = policy_model_kwargs() diff --git a/studies/single_run_with_export.py b/studies/single_run_with_export.py index 885fdbb..2227f49 100644 --- a/studies/single_run_with_export.py +++
b/studies/single_run_with_export.py @@ -213,7 +213,8 @@ if __name__ == '__main__': env_factory.save_params(param_path) # EnvMonitor Init - callbacks = [EnvMonitor(env_factory)] + env_monitor = EnvMonitor(env_factory) + callbacks = [env_monitor] # Model Init model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs, @@ -233,7 +234,7 @@ if __name__ == '__main__': model.save(save_path) # Monitor Save - callbacks[0].save_run(combination_path / 'monitor.pick', + env_monitor.save_run(combination_path / 'monitor.pick', auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys) # Better be save then sorry: Clean up!