mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	Adjustments and Documentation
This commit is contained in:
		| @@ -156,14 +156,14 @@ class BaseFactory(gym.Env): | ||||
|             np.argwhere(level_array == c.OCCUPIED_CELL), | ||||
|             self._level_shape | ||||
|         ) | ||||
|         self._entities.register_additional_items({c.WALLS: walls}) | ||||
|         self._entities.add_additional_items({c.WALLS: walls}) | ||||
|  | ||||
|         # Floor | ||||
|         floor = Floors.from_argwhere_coordinates( | ||||
|             np.argwhere(level_array == c.FREE_CELL), | ||||
|             self._level_shape | ||||
|         ) | ||||
|         self._entities.register_additional_items({c.FLOOR: floor}) | ||||
|         self._entities.add_additional_items({c.FLOOR: floor}) | ||||
|  | ||||
|         # NOPOS | ||||
|         self._NO_POS_TILE = Floor(c.NO_POS, None) | ||||
| @@ -177,12 +177,12 @@ class BaseFactory(gym.Env): | ||||
|                 doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area, | ||||
|                                          entity_kwargs=dict(context=floor) | ||||
|                                          ) | ||||
|                 self._entities.register_additional_items({c.DOORS: doors}) | ||||
|                 self._entities.add_additional_items({c.DOORS: doors}) | ||||
|  | ||||
|         # Actions | ||||
|         self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors) | ||||
|         if additional_actions := self.actions_hook: | ||||
|             self._actions.register_additional_items(additional_actions) | ||||
|             self._actions.add_additional_items(additional_actions) | ||||
|  | ||||
|         # Agents | ||||
|         agents_to_spawn = self.n_agents-len(self._injected_agents) | ||||
| @@ -196,10 +196,10 @@ class BaseFactory(gym.Env): | ||||
|         if self._injected_agents: | ||||
|             initialized_injections = list() | ||||
|             for i, injection in enumerate(self._injected_agents): | ||||
|                 agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False)) | ||||
|                 agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False)) | ||||
|                 initialized_injections.append(agents[-1]) | ||||
|             self._initialized_injections = initialized_injections | ||||
|         self._entities.register_additional_items({c.AGENT: agents}) | ||||
|         self._entities.add_additional_items({c.AGENT: agents}) | ||||
|  | ||||
|         if self.obs_prop.additional_agent_placeholder is not None: | ||||
|             # TODO: Make this accept Lists for multiple placeholders | ||||
| @@ -210,18 +210,18 @@ class BaseFactory(gym.Env): | ||||
|                                                        fill_value=self.obs_prop.additional_agent_placeholder) | ||||
|                                                    ) | ||||
|  | ||||
|             self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder}) | ||||
|             self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder}) | ||||
|  | ||||
|         # Additional Entities from SubEnvs | ||||
|         if additional_entities := self.entities_hook: | ||||
|             self._entities.register_additional_items(additional_entities) | ||||
|             self._entities.add_additional_items(additional_entities) | ||||
|  | ||||
|         if self.obs_prop.show_global_position_info: | ||||
|             global_positions = GlobalPositions(self._level_shape) | ||||
|             # This moved into the GlobalPosition object | ||||
|             # obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) | ||||
|             global_positions.spawn_global_position_objects(self[c.AGENT]) | ||||
|             self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) | ||||
|             self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions}) | ||||
|  | ||||
|         # Return | ||||
|         return self._entities | ||||
| @@ -535,7 +535,7 @@ class BaseFactory(gym.Env): | ||||
|  | ||||
|     def _check_agent_move(self, agent, action: Action) -> (Floor, bool): | ||||
|         # Actions | ||||
|         x_diff, y_diff = h.ACTIONMAP[action.identifier] | ||||
|         x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier) | ||||
|         x_new = agent.x + x_diff | ||||
|         y_new = agent.y + y_diff | ||||
|  | ||||
|   | ||||
| @@ -72,15 +72,15 @@ class EnvObject(Object): | ||||
|     def encoding(self): | ||||
|         return c.OCCUPIED_CELL | ||||
|  | ||||
|     def __init__(self, register, **kwargs): | ||||
|     def __init__(self, collection, **kwargs): | ||||
|         super(EnvObject, self).__init__(**kwargs) | ||||
|         self._register = register | ||||
|         self._collection = collection | ||||
|  | ||||
|     def change_register(self, register): | ||||
|         register.register_item(self) | ||||
|         self._register.delete_env_object(self) | ||||
|         self._register = register | ||||
|         return self._register == register | ||||
|     def change_parent_collection(self, other_collection): | ||||
|         other_collection.add_item(self) | ||||
|         self._collection.delete_env_object(self) | ||||
|         self._collection = other_collection | ||||
|         return self._collection == other_collection | ||||
| # With Rendering | ||||
|  | ||||
|  | ||||
| @@ -153,7 +153,7 @@ class MoveableEntity(Entity): | ||||
|             curr_tile.leave(self) | ||||
|             self._tile = next_tile | ||||
|             self._last_tile = curr_tile | ||||
|             self._register.notify_change_to_value(self) | ||||
|             self._collection.notify_change_to_value(self) | ||||
|             return c.VALID | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
| @@ -371,13 +371,13 @@ class Door(Entity): | ||||
|     def _open(self): | ||||
|         self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))]) | ||||
|         self._state = c.OPEN_DOOR | ||||
|         self._register.notify_change_to_value(self) | ||||
|         self._collection.notify_change_to_value(self) | ||||
|         self.time_to_close = self.auto_close_interval | ||||
|  | ||||
|     def _close(self): | ||||
|         self.connectivity.remove_node(self.pos) | ||||
|         self._state = c.CLOSED_DOOR | ||||
|         self._register.notify_change_to_value(self) | ||||
|         self._collection.notify_change_to_value(self) | ||||
|  | ||||
|     def is_linked(self, old_pos, new_pos): | ||||
|         try: | ||||
|   | ||||
| @@ -13,11 +13,11 @@ from environments import helpers as h | ||||
| from environments.helpers import Constants as c | ||||
|  | ||||
| ########################################################################## | ||||
| # ##################### Base Register Definition ####################### # | ||||
| # ################## Base Collections Definition ####################### # | ||||
| ########################################################################## | ||||
|  | ||||
|  | ||||
| class ObjectRegister: | ||||
| class ObjectCollection: | ||||
|     _accepted_objects = Object | ||||
|  | ||||
|     @property | ||||
| @@ -25,59 +25,59 @@ class ObjectRegister: | ||||
|         return f'{self.__class__.__name__}' | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         self._register = dict() | ||||
|         self._collection = dict() | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._register) | ||||
|         return len(self._collection) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self.values()) | ||||
|  | ||||
|     def register_item(self, other: _accepted_objects): | ||||
|     def add_item(self, other: _accepted_objects): | ||||
|         assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \ | ||||
|                                                           f'{self._accepted_objects}, ' \ | ||||
|                                                           f'but were {other.__class__}.,' | ||||
|         self._register.update({other.name: other}) | ||||
|         self._collection.update({other.name: other}) | ||||
|         return self | ||||
|  | ||||
|     def register_additional_items(self, others: List[_accepted_objects]): | ||||
|     def add_additional_items(self, others: List[_accepted_objects]): | ||||
|         for other in others: | ||||
|             self.register_item(other) | ||||
|             self.add_item(other) | ||||
|         return self | ||||
|  | ||||
|     def keys(self): | ||||
|         return self._register.keys() | ||||
|         return self._collection.keys() | ||||
|  | ||||
|     def values(self): | ||||
|         return self._register.values() | ||||
|         return self._collection.values() | ||||
|  | ||||
|     def items(self): | ||||
|         return self._register.items() | ||||
|         return self._collection.items() | ||||
|  | ||||
|     def _get_index(self, item): | ||||
|         try: | ||||
|             return next(i for i, v in enumerate(self._register.values()) if v == item) | ||||
|             return next(i for i, v in enumerate(self._collection.values()) if v == item) | ||||
|         except StopIteration: | ||||
|             return None | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         if isinstance(item, (int, np.int64, np.int32)): | ||||
|             if item < 0: | ||||
|                 item = len(self._register) - abs(item) | ||||
|                 item = len(self._collection) - abs(item) | ||||
|             try: | ||||
|                 return next(v for i, v in enumerate(self._register.values()) if i == item) | ||||
|                 return next(v for i, v in enumerate(self._collection.values()) if i == item) | ||||
|             except StopIteration: | ||||
|                 return None | ||||
|         try: | ||||
|             return self._register[item] | ||||
|             return self._collection[item] | ||||
|         except KeyError: | ||||
|             return None | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.__class__.__name__}[{self._register}]' | ||||
|         return f'{self.__class__.__name__}[{self._collection}]' | ||||
|  | ||||
|  | ||||
| class EnvObjectRegister(ObjectRegister): | ||||
| class EnvObjectCollection(ObjectCollection): | ||||
|  | ||||
|     _accepted_objects = EnvObject | ||||
|  | ||||
| @@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister): | ||||
|                  is_blocking_light: bool = False, | ||||
|                  can_collide: bool = False, | ||||
|                  can_be_shadowed: bool = True, **kwargs): | ||||
|         super(EnvObjectRegister, self).__init__(*args, **kwargs) | ||||
|         super(EnvObjectCollection, self).__init__(*args, **kwargs) | ||||
|         self._shape = obs_shape | ||||
|         self._array = None | ||||
|         self._individual_slices = individual_slices | ||||
| @@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister): | ||||
|         self.can_be_shadowed = can_be_shadowed | ||||
|         self.can_collide = can_collide | ||||
|  | ||||
|     def register_item(self, other: EnvObject): | ||||
|         super(EnvObjectRegister, self).register_item(other) | ||||
|     def add_item(self, other: EnvObject): | ||||
|         super(EnvObjectCollection, self).add_item(other) | ||||
|         if self._array is None: | ||||
|             self._array = np.zeros((1, *self._shape)) | ||||
|         else: | ||||
| @@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister): | ||||
|         if self._individual_slices: | ||||
|             self._array = np.delete(self._array, idx, axis=0) | ||||
|         else: | ||||
|             self.notify_change_to_free(self._register[name]) | ||||
|             self.notify_change_to_free(self._collection[name]) | ||||
|             # Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions | ||||
|             # in the observation array are the result of enumeration. They can override each other. | ||||
|             # Todo: Find a better solution | ||||
|             if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister): | ||||
|             if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection): | ||||
|                 self._refresh_arrays() | ||||
|         del self._register[name] | ||||
|         del self._collection[name] | ||||
|  | ||||
|     def delete_env_object(self, env_object: EnvObject): | ||||
|         del self[env_object.name] | ||||
| @@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister): | ||||
|         del self[name] | ||||
|  | ||||
|  | ||||
| class EntityRegister(EnvObjectRegister, ABC): | ||||
| class EntityCollection(EnvObjectCollection, ABC): | ||||
|  | ||||
|     _accepted_objects = Entity | ||||
|  | ||||
|     @classmethod | ||||
|     def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs): | ||||
|         # objects_name = cls._accepted_objects.__name__ | ||||
|         register_obj = cls(*args, **kwargs) | ||||
|         entities = [cls._accepted_objects(tile, register_obj, str_ident=i, | ||||
|         collection = cls(*args, **kwargs) | ||||
|         entities = [cls._accepted_objects(tile, collection, str_ident=i, | ||||
|                                           **entity_kwargs if entity_kwargs is not None else {}) | ||||
|                     for i, tile in enumerate(tiles)] | ||||
|         register_obj.register_additional_items(entities) | ||||
|         return register_obj | ||||
|         collection.add_additional_items(entities) | ||||
|         return collection | ||||
|  | ||||
|     @classmethod | ||||
|     def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ): | ||||
| @@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC): | ||||
|         return [entity.tile for entity in self] | ||||
|  | ||||
|     def __init__(self, level_shape, *args, **kwargs): | ||||
|         super(EntityRegister, self).__init__(level_shape, *args, **kwargs) | ||||
|         super(EntityCollection, self).__init__(level_shape, *args, **kwargs) | ||||
|         self._lazy_eval_transforms = [] | ||||
|  | ||||
|     def __delitem__(self, name): | ||||
|         idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) | ||||
|         obj.tile.leave(obj) | ||||
|         super(EntityRegister, self).__delitem__(name) | ||||
|         super(EntityCollection, self).__delitem__(name) | ||||
|  | ||||
|     def as_array(self): | ||||
|         if self._lazy_eval_transforms: | ||||
| @@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC): | ||||
|             return None | ||||
|  | ||||
|  | ||||
| class BoundEnvObjRegister(EnvObjectRegister, ABC): | ||||
| class BoundEnvObjCollection(EnvObjectCollection, ABC): | ||||
|  | ||||
|     def __init__(self, entity_to_be_bound, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
| @@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC): | ||||
|         return self._array[self.idx_by_entity(entity)] | ||||
|  | ||||
|  | ||||
| class MovingEntityObjectRegister(EntityRegister, ABC): | ||||
| class MovingEntityObjectCollection(EntityCollection, ABC): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(MovingEntityObjectRegister, self).__init__(*args, **kwargs) | ||||
|         super(MovingEntityObjectCollection, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def notify_change_to_value(self, entity): | ||||
|         super(MovingEntityObjectRegister, self).notify_change_to_value(entity) | ||||
|         super(MovingEntityObjectCollection, self).notify_change_to_value(entity) | ||||
|         if entity.last_pos != c.NO_POS: | ||||
|             try: | ||||
|                 self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL) | ||||
| @@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC): | ||||
|  | ||||
|  | ||||
| ########################################################################## | ||||
| # ################# Objects and Entity Registers ####################### # | ||||
| # ################# Objects and Entity Collection ###################### # | ||||
| ########################################################################## | ||||
|  | ||||
|  | ||||
| class GlobalPositions(EnvObjectRegister): | ||||
| class GlobalPositions(EnvObjectCollection): | ||||
|  | ||||
|     _accepted_objects = GlobalPosition | ||||
|  | ||||
| @@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister): | ||||
|         global_positions = [self._accepted_objects(self._shape, agent, self) | ||||
|                             for _, agent in enumerate(agents)] | ||||
|         # noinspection PyTypeChecker | ||||
|         self.register_additional_items(global_positions) | ||||
|         self.add_additional_items(global_positions) | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         return {} | ||||
| @@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister): | ||||
|             return None | ||||
|  | ||||
|  | ||||
| class PlaceHolders(EnvObjectRegister): | ||||
| class PlaceHolders(EnvObjectCollection): | ||||
|     _accepted_objects = PlaceHolder | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
| @@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister): | ||||
|         # objects_name = cls._accepted_objects.__name__ | ||||
|         if isinstance(values, (str, numbers.Number)): | ||||
|             values = [values] | ||||
|         register_obj = cls(*args, **kwargs) | ||||
|         objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value, | ||||
|         collection = cls(*args, **kwargs) | ||||
|         objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value, | ||||
|                                          **object_kwargs if object_kwargs is not None else {}) | ||||
|                    for i, value in enumerate(values)] | ||||
|         register_obj.register_additional_items(objects) | ||||
|         return register_obj | ||||
|         collection.add_additional_items(objects) | ||||
|         return collection | ||||
|  | ||||
|     # noinspection DuplicatedCode | ||||
|     def as_array(self): | ||||
| @@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister): | ||||
|         return self._array | ||||
|  | ||||
|  | ||||
| class Entities(ObjectRegister): | ||||
|     _accepted_objects = EntityRegister | ||||
| class Entities(ObjectCollection): | ||||
|     _accepted_objects = EntityCollection | ||||
|  | ||||
|     @property | ||||
|     def arrays(self): | ||||
| @@ -352,7 +352,7 @@ class Entities(ObjectRegister): | ||||
|  | ||||
|     @property | ||||
|     def names(self): | ||||
|         return list(self._register.keys()) | ||||
|         return list(self._collection.keys()) | ||||
|  | ||||
|     def __init__(self): | ||||
|         super(Entities, self).__init__() | ||||
| @@ -360,21 +360,21 @@ class Entities(ObjectRegister): | ||||
|     def iter_individual_entitites(self): | ||||
|         return iter((x for sublist in self.values() for x in sublist)) | ||||
|  | ||||
|     def register_item(self, other: dict): | ||||
|     def add_item(self, other: dict): | ||||
|         assert not any([key for key in other.keys() if key in self.keys()]), \ | ||||
|             "This group of entities has already been registered!" | ||||
|         self._register.update(other) | ||||
|             "This group of entities has already been added!" | ||||
|         self._collection.update(other) | ||||
|         return self | ||||
|  | ||||
|     def register_additional_items(self, others: Dict): | ||||
|         return self.register_item(others) | ||||
|     def add_additional_items(self, others: Dict): | ||||
|         return self.add_item(others) | ||||
|  | ||||
|     def by_pos(self, pos: (int, int)): | ||||
|         found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None] | ||||
|         return found_entities | ||||
|  | ||||
|  | ||||
| class Walls(EntityRegister): | ||||
| class Walls(EntityCollection): | ||||
|     _accepted_objects = Wall | ||||
|  | ||||
|     def as_array(self): | ||||
| @@ -396,7 +396,7 @@ class Walls(EntityRegister): | ||||
|     def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs): | ||||
|         tiles = cls(*args, **kwargs) | ||||
|         # noinspection PyTypeChecker | ||||
|         tiles.register_additional_items( | ||||
|         tiles.add_additional_items( | ||||
|             [cls._accepted_objects(pos, tiles) | ||||
|              for pos in argwhere_coordinates] | ||||
|         ) | ||||
| @@ -441,7 +441,7 @@ class Floors(Walls): | ||||
|         return {} | ||||
|  | ||||
|  | ||||
| class Agents(MovingEntityObjectRegister): | ||||
| class Agents(MovingEntityObjectCollection): | ||||
|     _accepted_objects = Agent | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
| @@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister): | ||||
|         old_agent = self[key] | ||||
|         self[key].tile.leave(self[key]) | ||||
|         agent._name = old_agent.name | ||||
|         self._register[agent.name] = agent | ||||
|         self._collection[agent.name] = agent | ||||
|  | ||||
|  | ||||
| class Doors(EntityRegister): | ||||
| class Doors(EntityCollection): | ||||
|  | ||||
|     def __init__(self, *args, have_area: bool = False, **kwargs): | ||||
|         self.have_area = have_area | ||||
| @@ -490,7 +490,7 @@ class Doors(EntityRegister): | ||||
|         return super(Doors, self).as_array() | ||||
|  | ||||
|  | ||||
| class Actions(ObjectRegister): | ||||
| class Actions(ObjectCollection): | ||||
|     _accepted_objects = Action | ||||
|  | ||||
|     @property | ||||
| @@ -507,22 +507,22 @@ class Actions(ObjectRegister): | ||||
|  | ||||
|         # Move this to Baseclass, Env init? | ||||
|         if self.allow_square_movement: | ||||
|             self.register_additional_items([self._accepted_objects(str_ident=direction) | ||||
|                                             for direction in h.EnvActions.square_move()]) | ||||
|             self.add_additional_items([self._accepted_objects(str_ident=direction) | ||||
|                                        for direction in h.EnvActions.square_move()]) | ||||
|         if self.allow_diagonal_movement: | ||||
|             self.register_additional_items([self._accepted_objects(str_ident=direction) | ||||
|                                             for direction in h.EnvActions.diagonal_move()]) | ||||
|         self._movement_actions = self._register.copy() | ||||
|             self.add_additional_items([self._accepted_objects(str_ident=direction) | ||||
|                                        for direction in h.EnvActions.diagonal_move()]) | ||||
|         self._movement_actions = self._collection.copy() | ||||
|         if self.can_use_doors: | ||||
|             self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) | ||||
|             self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) | ||||
|         if self.allow_no_op: | ||||
|             self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) | ||||
|             self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) | ||||
|  | ||||
|     def is_moving_action(self, action: Union[int]): | ||||
|         return action in self.movement_actions.values() | ||||
|  | ||||
|  | ||||
| class Zones(ObjectRegister): | ||||
| class Zones(ObjectCollection): | ||||
|  | ||||
|     @property | ||||
|     def accounting_zones(self): | ||||
| @@ -551,5 +551,5 @@ class Zones(ObjectRegister): | ||||
|     def __getitem__(self, item): | ||||
|         return self._zone_slices[item] | ||||
|  | ||||
|     def register_additional_items(self, other: Union[str, List[str]]): | ||||
|     def add_additional_items(self, other: Union[str, List[str]]): | ||||
|         raise AttributeError('You are not allowed to add additional Zones in runtime.') | ||||
|   | ||||
| @@ -4,7 +4,7 @@ import numpy as np | ||||
|  | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin | ||||
| from environments.factory.base.registers import EntityRegister, EnvObjectRegister | ||||
| from environments.factory.base.registers import EntityCollection, EnvObjectCollection | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
| from environments.helpers import Constants as BaseConstants | ||||
| from environments.helpers import EnvActions as BaseActions | ||||
| @@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject): | ||||
|         if self.charge_level != 0: | ||||
|             # noinspection PyTypeChecker | ||||
|             self.charge_level = max(0, amount + self.charge_level) | ||||
|             self._register.notify_change_to_value(self) | ||||
|             self._collection.notify_change_to_value(self) | ||||
|             return c.VALID | ||||
|         else: | ||||
|             return c.NOT_VALID | ||||
| @@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject): | ||||
|         return attr_dict | ||||
|  | ||||
|  | ||||
| class BatteriesRegister(EnvObjectRegister): | ||||
| class BatteriesRegister(EnvObjectCollection): | ||||
|  | ||||
|     _accepted_objects = Battery | ||||
|  | ||||
| @@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister): | ||||
|  | ||||
|     def spawn_batteries(self, agents, initial_charge_level): | ||||
|         batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] | ||||
|         self.register_additional_items(batteries) | ||||
|         self.add_additional_items(batteries) | ||||
|  | ||||
|     def summarize_states(self, n_steps=None): | ||||
|         # as dict with additional nesting | ||||
| @@ -140,7 +140,7 @@ class ChargePod(Entity): | ||||
|             return summary | ||||
|  | ||||
|  | ||||
| class ChargePods(EntityRegister): | ||||
| class ChargePods(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = ChargePod | ||||
|  | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.helpers import Constants as BaseConstants | ||||
| from environments.helpers import EnvActions as BaseActions | ||||
| from environments.factory.base.objects import Agent, Entity, Action | ||||
| from environments.factory.base.registers import Entities, EntityRegister | ||||
| from environments.factory.base.registers import Entities, EntityCollection | ||||
|  | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
|  | ||||
| @@ -73,7 +73,7 @@ class Destination(Entity): | ||||
|         return state_summary | ||||
|  | ||||
|  | ||||
| class Destinations(EntityRegister): | ||||
| class Destinations(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = Destination | ||||
|  | ||||
| @@ -208,13 +208,13 @@ class DestFactory(BaseFactory): | ||||
|             n_dest_to_spawn = len(destinations_to_spawn) | ||||
|             if self.dest_prop.spawn_mode != DestModeOptions.GROUPED: | ||||
|                 destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] | ||||
|                 self[c.DEST].register_additional_items(destinations) | ||||
|                 self[c.DEST].add_additional_items(destinations) | ||||
|                 for dest in destinations_to_spawn: | ||||
|                     del self._dest_spawn_timer[dest] | ||||
|                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') | ||||
|             elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests: | ||||
|                 destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] | ||||
|                 self[c.DEST].register_additional_items(destinations) | ||||
|                 self[c.DEST].add_additional_items(destinations) | ||||
|                 for dest in destinations_to_spawn: | ||||
|                     del self._dest_spawn_timer[dest] | ||||
|                 self.print(f'{n_dest_to_spawn} new destinations have been spawned') | ||||
| @@ -231,7 +231,7 @@ class DestFactory(BaseFactory): | ||||
|             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) | ||||
|         for dest in list(self[c.DEST].values()): | ||||
|             if dest.is_considered_reached: | ||||
|                 dest.change_register(self[c.DEST]) | ||||
|                 dest.change_parent_collection(self[c.DEST]) | ||||
|                 self._dest_spawn_timer[dest.name] = 0 | ||||
|                 self.print(f'{dest.name} is reached now, removing...') | ||||
|             else: | ||||
|   | ||||
| @@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions | ||||
|  | ||||
| from environments.factory.base.base_factory import BaseFactory | ||||
| from environments.factory.base.objects import Agent, Action, Entity, Floor | ||||
| from environments.factory.base.registers import Entities, EntityRegister | ||||
| from environments.factory.base.registers import Entities, EntityCollection | ||||
|  | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
| from environments.utility_classes import ObservationProperties | ||||
| @@ -61,7 +61,7 @@ class Dirt(Entity): | ||||
|  | ||||
|     def set_new_amount(self, amount): | ||||
|         self._amount = amount | ||||
|         self._register.notify_change_to_value(self) | ||||
|         self._collection.notify_change_to_value(self) | ||||
|  | ||||
|     def summarize_state(self, **kwargs): | ||||
|         state_dict = super().summarize_state(**kwargs) | ||||
| @@ -69,7 +69,7 @@ class Dirt(Entity): | ||||
|         return state_dict | ||||
|  | ||||
|  | ||||
| class DirtRegister(EntityRegister): | ||||
| class DirtRegister(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = Dirt | ||||
|  | ||||
| @@ -93,7 +93,7 @@ class DirtRegister(EntityRegister): | ||||
|                 dirt = self.by_pos(tile.pos) | ||||
|                 if dirt is None: | ||||
|                     dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount) | ||||
|                     self.register_item(dirt) | ||||
|                     self.add_item(dirt) | ||||
|                 else: | ||||
|                     new_value = dirt.amount + self.dirt_properties.max_spawn_amount | ||||
|                     dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount)) | ||||
|   | ||||
| @@ -5,10 +5,10 @@ import numpy as np | ||||
| from environments.factory.base.objects import Agent, Entity, Action | ||||
| from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory | ||||
| from environments.factory.base.objects import Floor | ||||
| from environments.factory.base.registers import Floors, Entities, EntityRegister | ||||
| from environments.factory.base.registers import Floors, Entities, EntityCollection | ||||
|  | ||||
|  | ||||
| class Machines(EntityRegister): | ||||
| class Machines(EntityCollection): | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants | ||||
| from environments.helpers import EnvActions as BaseActions | ||||
| from environments import helpers as h | ||||
| from environments.factory.base.objects import Agent, Entity, Action, Floor | ||||
| from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister | ||||
| from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection | ||||
|  | ||||
| from environments.factory.base.renderer import RenderEntity | ||||
|  | ||||
| @@ -53,17 +53,17 @@ class Item(Entity): | ||||
|         self._auto_despawn = auto_despawn | ||||
|  | ||||
|     def set_tile_to(self, no_pos_tile): | ||||
|         assert self._register.__class__.__name__ != ItemRegister.__class__ | ||||
|         assert self._collection.__class__.__name__ != ItemRegister.__class__ | ||||
|         self._tile = no_pos_tile | ||||
|  | ||||
|  | ||||
| class ItemRegister(EntityRegister): | ||||
| class ItemRegister(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = Item | ||||
|  | ||||
|     def spawn_items(self, tiles: List[Floor]): | ||||
|         items = [Item(tile, self) for tile in tiles] | ||||
|         self.register_additional_items(items) | ||||
|         self.add_additional_items(items) | ||||
|  | ||||
|     def despawn_items(self, items: List[Item]): | ||||
|         items = [items] if isinstance(items, Item) else items | ||||
| @@ -71,7 +71,7 @@ class ItemRegister(EntityRegister): | ||||
|             del self[item] | ||||
|  | ||||
|  | ||||
| class Inventory(BoundEnvObjRegister): | ||||
| class Inventory(BoundEnvObjCollection): | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
| @@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister): | ||||
|         return item_to_pop | ||||
|  | ||||
|  | ||||
| class Inventories(ObjectRegister): | ||||
| class Inventories(ObjectCollection): | ||||
|  | ||||
|     _accepted_objects = Inventory | ||||
|     is_blocking_light = False | ||||
| @@ -114,7 +114,7 @@ class Inventories(ObjectRegister): | ||||
|     def spawn_inventories(self, agents, capacity): | ||||
|         inventories = [self._accepted_objects(agent, capacity, self._obs_shape) | ||||
|                        for _, agent in enumerate(agents)] | ||||
|         self.register_additional_items(inventories) | ||||
|         self.add_additional_items(inventories) | ||||
|  | ||||
|     def idx_by_entity(self, entity): | ||||
|         try: | ||||
| @@ -161,7 +161,7 @@ class DropOffLocation(Entity): | ||||
|             return super().summarize_state(n_steps=n_steps) | ||||
|  | ||||
|  | ||||
| class DropOffLocations(EntityRegister): | ||||
| class DropOffLocations(EntityCollection): | ||||
|  | ||||
|     _accepted_objects = DropOffLocation | ||||
|  | ||||
| @@ -250,7 +250,7 @@ class ItemFactory(BaseFactory): | ||||
|                           reason=a.ITEM_ACTION, info=info_dict) | ||||
|             return valid, reward | ||||
|         elif item := self[c.ITEM].by_pos(agent.pos): | ||||
|             item.change_register(inventory) | ||||
|             item.change_parent_collection(inventory) | ||||
|             item.set_tile_to(self._NO_POS_TILE) | ||||
|             self.print(f'{agent.name} just picked up an item at {agent.pos}') | ||||
|             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1} | ||||
|   | ||||
| @@ -7,47 +7,76 @@ import numpy as np | ||||
| from numpy.typing import ArrayLike | ||||
| from stable_baselines3 import PPO, DQN, A2C | ||||
|  | ||||
| MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) | ||||
|  | ||||
| LEVELS_DIR = 'levels' | ||||
| STEPS_START = 1 | ||||
|  | ||||
| TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] | ||||
| IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', | ||||
|                       'dirty_tile_count', 'terminal_observation', 'episode'] | ||||
| """ | ||||
| This file is used for: | ||||
|     1. string based definition | ||||
|         Use a class like `Constants`, to define attributes, which then reveal strings. | ||||
|         These can be used for naming convention along the environments as well as keys for mappings such as dicts etc. | ||||
|         When defining new envs, use class inheritance.  | ||||
|      | ||||
|     2. utility function definition | ||||
|         There are static utility functions which are not bound to a specific environment. | ||||
|         In this file they are defined to be used across the entire package. | ||||
| """ | ||||
|  | ||||
|  | ||||
| MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)      # For use in studies and experiments | ||||
|  | ||||
|  | ||||
| LEVELS_DIR = 'levels'                            # for use in studies and experiments | ||||
| STEPS_START = 1                                  # Defines where the step count starts, i.e. the index of the first step | ||||
|  | ||||
| # Not used anymore? Clean! | ||||
| # TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] | ||||
| IGNORED_DF_COLUMNS = ['Episode', 'Run',          # For plotting, which values are ignored when loading monitor files | ||||
|                       'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation', | ||||
|                       'episode'] | ||||
|  | ||||
|  | ||||
| # Constants | ||||
| class Constants: | ||||
|     WALL                = '#' | ||||
|     WALLS               = 'Walls' | ||||
|     FLOOR               = 'Floor' | ||||
|     DOOR                = 'D' | ||||
|     DANGER_ZONE         = 'x' | ||||
|     LEVEL               = 'Level' | ||||
|     AGENT               = 'Agent' | ||||
|     AGENT_PLACEHOLDER   = 'AGENT_PLACEHOLDER' | ||||
|     GLOBAL_POSITION     = 'GLOBAL_POSITION' | ||||
|     FREE_CELL           = 0 | ||||
|     OCCUPIED_CELL       = 1 | ||||
|     SHADOWED_CELL       = -1 | ||||
|     ACCESS_DOOR_CELL    = 1/3 | ||||
|     OPEN_DOOR_CELL      = 2/3 | ||||
|     CLOSED_DOOR_CELL    = 3/3 | ||||
|     NO_POS              = (-9999, -9999) | ||||
|  | ||||
|     DOORS               = 'Doors' | ||||
|     CLOSED_DOOR         = 'closed' | ||||
|     OPEN_DOOR           = 'open' | ||||
|     ACCESS_DOOR         = 'access' | ||||
|     """ | ||||
|         String-based mapping. Use these to handle keys or define values, which can then be used globally. | ||||
|         Please use class inheritance when defining new environments. | ||||
|     """ | ||||
|  | ||||
|     ACTION              = 'action' | ||||
|     COLLISION          = 'collision' | ||||
|     VALID               = True | ||||
|     NOT_VALID           = False | ||||
|     WALL                = '#'                   # Wall tile identifier for resolving the string based map files. | ||||
|     DOOR                = 'D'                   # Door identifier for resolving the string based map files. | ||||
|     DANGER_ZONE         = 'x'                   # Danger zone tile identifier for resolving the string based map files. | ||||
|  | ||||
|     WALLS               = 'Walls'               # Identifier of Wall-objects and sets (collections). | ||||
|     FLOOR               = 'Floor'               # Identifier of Floor-objects and sets (collections). | ||||
|     DOORS               = 'Doors'               # Identifier of Door-objects and sets (collections). | ||||
|     LEVEL               = 'Level'               # Identifier of Level-objects and sets (collections). | ||||
|     AGENT               = 'Agent'               # Identifier of Agent-objects and sets (collections). | ||||
|     AGENT_PLACEHOLDER   = 'AGENT_PLACEHOLDER'   # Identifier of Placeholder-objects and sets (collections). | ||||
|     GLOBAL_POSITION     = 'GLOBAL_POSITION'     # Identifier of the global position slice | ||||
|  | ||||
|     FREE_CELL           = 0                     # Free-Cell value used in observation | ||||
|     OCCUPIED_CELL       = 1                     # Occupied-Cell value used in observation | ||||
|     SHADOWED_CELL       = -1                    # Shadowed-Cell value used in observation | ||||
|     ACCESS_DOOR_CELL    = 1/3                   # Access-door-Cell value used in observation | ||||
|     OPEN_DOOR_CELL      = 2/3                   # Open-door-Cell value used in observation | ||||
|     CLOSED_DOOR_CELL    = 3/3                   # Closed-door-Cell value used in observation | ||||
|  | ||||
|     NO_POS              = (-9999, -9999)        # Invalid Position value used in the environment (something is off-grid) | ||||
|  | ||||
|     CLOSED_DOOR         = 'closed'              # Identifier to compare door-is-closed state | ||||
|     OPEN_DOOR           = 'open'                # Identifier to compare door-is-open state | ||||
|     # ACCESS_DOOR         = 'access'            # Identifier to compare access positions | ||||
|  | ||||
|     ACTION              = 'action'              # Identifier of Action-objects and sets (collections). | ||||
|     COLLISION           = 'collision'           # Identifier to use in the context of collisions. | ||||
|     VALID               = True                  # Identifier to rename boolean values in the context of actions. | ||||
|     NOT_VALID           = False                 # Identifier to rename boolean values in the context of actions. | ||||
|  | ||||
|  | ||||
| class EnvActions: | ||||
|     """ | ||||
|         String-based mapping. Use these to identify actions; can be used globally. | ||||
|         Please use class inheritance when defining new environments with new actions. | ||||
|     """ | ||||
|     # Movements | ||||
|     NORTH           = 'north' | ||||
|     EAST            = 'east' | ||||
| @@ -63,24 +92,77 @@ class EnvActions: | ||||
|     NOOP            = 'no_op' | ||||
|     USE_DOOR        = 'use_door' | ||||
|  | ||||
|     _ACTIONMAP = defaultdict(lambda: (0, 0), | ||||
|                             {NORTH: (-1, 0),    NORTHEAST: (-1, 1), | ||||
|                              EAST:  (0, 1),     SOUTHEAST: (1, 1), | ||||
|                              SOUTH: (1, 0),     SOUTHWEST: (1, -1), | ||||
|                              WEST:  (0, -1),    NORTHWEST: (-1, -1) | ||||
|                              } | ||||
|                             ) | ||||
|  | ||||
|     @classmethod | ||||
|     def is_move(cls, other): | ||||
|         return any([other == direction for direction in cls.movement_actions()]) | ||||
|     def is_move(cls, action): | ||||
|         """ | ||||
|         Classmethod; checks if given action is a movement action or not. Depending on the env. configuration, | ||||
|         Movement actions are either `manhattan` (square) style movements (up,down, left, right) and/or diagonal. | ||||
|  | ||||
|         :param action:  Action to be checked | ||||
|         :type action:   str | ||||
|         :return:        Whether the given action is a movement action. | ||||
|         :rtype:         bool | ||||
|         """ | ||||
|         return any([action == direction for direction in cls.movement_actions()]) | ||||
|  | ||||
|     @classmethod | ||||
|     def square_move(cls): | ||||
|         """ | ||||
|         Classmethod; return a list of movement actions that are considered square or `manhattan` style movements. | ||||
|  | ||||
|         :return: A list of movement actions. | ||||
|         :rtype: list(str) | ||||
|         """ | ||||
|         return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] | ||||
|  | ||||
|     @classmethod | ||||
|     def diagonal_move(cls): | ||||
|         """ | ||||
|         Classmethod; return a list of movement actions that are considered diagonal movements. | ||||
|  | ||||
|         :return: A list of movement actions. | ||||
|         :rtype: list(str) | ||||
|         """ | ||||
|         return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] | ||||
|  | ||||
|     @classmethod | ||||
|     def movement_actions(cls): | ||||
|         """ | ||||
|         Classmethod; return a list of all available movement actions. | ||||
|         Please note that this is independent of the env. properties. | ||||
|  | ||||
|         :return: A list of movement actions. | ||||
|         :rtype: list(str) | ||||
|         """ | ||||
|         return list(itertools.chain(cls.square_move(), cls.diagonal_move())) | ||||
|  | ||||
|     @classmethod | ||||
|     def resolve_movement_action_to_coords(cls, action): | ||||
|         """ | ||||
|         Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for. | ||||
|         How does the current entity coordinate change if it performs the given action? | ||||
|         Please note that this is independent of the env. properties. | ||||
|  | ||||
|         :return: Delta coordinates. | ||||
|         :rtype: tuple(int, int) | ||||
|         """ | ||||
|         return cls._ACTIONMAP[action] | ||||
|  | ||||
|  | ||||
| class RewardsBase(NamedTuple): | ||||
|     """ | ||||
|         Value based mapping. Use these to define reward values for specific conditions (i.e. the action | ||||
|         in a given context); can be used globally. | ||||
|         Please use class inheritance when defining new environments with new rewards. | ||||
|     """ | ||||
|     MOVEMENTS_VALID: float = -0.001 | ||||
|     MOVEMENTS_FAIL: float  = -0.05 | ||||
|     NOOP: float            = -0.01 | ||||
| @@ -89,23 +171,31 @@ class RewardsBase(NamedTuple): | ||||
|     COLLISION: float       = -0.5 | ||||
|  | ||||
|  | ||||
| m = EnvActions | ||||
| c = Constants | ||||
|  | ||||
| ACTIONMAP = defaultdict(lambda: (0, 0), | ||||
|                         {m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1), | ||||
|                          m.EAST: (0, 1),   m.SOUTHEAST: (1, 1), | ||||
|                          m.SOUTH: (1, 0),  m.SOUTHWEST: (1, -1), | ||||
|                          m.WEST: (0, -1),  m.NORTHWEST: (-1, -1) | ||||
|                          } | ||||
|                         ) | ||||
|  | ||||
|  | ||||
| class ObservationTranslator: | ||||
|  | ||||
|     def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict], | ||||
|                  *per_agent_named_obs_space: Dict[str, dict], | ||||
|                  *per_agent_named_obs_spaces: Dict[str, dict], | ||||
|                  placeholder_fill_value: Union[int, str] = 'N'): | ||||
|         """ | ||||
|         This is a helper class, which converts agents observations from joined environments. | ||||
|         For example, agents trained in different environments may expect different observations. | ||||
|         This class translates from larger observations spaces to smaller. | ||||
|         A string identifier based approach is used. | ||||
|         Currently, it is not possible to mix different obs shapes. | ||||
|  | ||||
|         :param obs_shape_2d: The shape of the observation the agents expect. | ||||
|         :type  obs_shape_2d: tuple(int, int) | ||||
|  | ||||
|         :param this_named_observation_space: `Named observation space` of the joined environment. | ||||
|         :type  this_named_observation_space: Dict[str, dict] | ||||
|  | ||||
|         :param per_agent_named_obs_spaces: `Named observation space`, one for each agent (variadic). | ||||
|         :type  per_agent_named_obs_spaces: Dict[str, dict] | ||||
|  | ||||
|         :param placeholder_fill_value: Currently not fully implemented!!! | ||||
|         :type  placeholder_fill_value: Union[int, str], defaults to 'N' | ||||
|         """ | ||||
|  | ||||
|         assert len(obs_shape_2d) == 2 | ||||
|         self.obs_shape = obs_shape_2d | ||||
|         if isinstance(placeholder_fill_value, str): | ||||
| @@ -119,7 +209,7 @@ class ObservationTranslator: | ||||
|             self.random_fill = None | ||||
|  | ||||
|         self._this_named_obs_space = this_named_observation_space | ||||
|         self._per_agent_named_obs_space = list(per_agent_named_obs_space) | ||||
|         self._per_agent_named_obs_space = list(per_agent_named_obs_spaces) | ||||
|  | ||||
|     def translate_observation(self, agent_idx: int, obs: np.ndarray): | ||||
|         target_obs_space = self._per_agent_named_obs_space[agent_idx] | ||||
| @@ -137,6 +227,19 @@ class ObservationTranslator: | ||||
| class ActionTranslator: | ||||
|  | ||||
|     def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]): | ||||
|         """ | ||||
|         This is a helper class, which converts agents action spaces to a joined environments action space. | ||||
|         For example, agents trained in different environments may have different action spaces. | ||||
|         This class translates from smaller individual agent action spaces to larger joined spaces. | ||||
|         A string identifier based approach is used. | ||||
|  | ||||
|         :param target_named_action_space:  Joined `Named action space` for the current environment. | ||||
|         :type target_named_action_space: Dict[str, dict] | ||||
|  | ||||
|         :param per_agent_named_action_space: `Named action space` one for each agent. Overloaded. | ||||
|         :type per_agent_named_action_space: Dict[str, dict] | ||||
|         """ | ||||
|  | ||||
|         self._target_named_action_space = target_named_action_space | ||||
|         self._per_agent_named_action_space = list(per_agent_named_action_space) | ||||
|         self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space] | ||||
| @@ -155,6 +258,16 @@ class ActionTranslator: | ||||
|  | ||||
| # Utility functions | ||||
| def parse_level(path): | ||||
|     """ | ||||
|     Given the path to a string-based `level` or `map` representation, this function reads the content. | ||||
|     Strips surrounding whitespace from each row, checks for equal length of each row and returns a list of lists. | ||||
|  | ||||
|     :param path: Path to the `level` or `map` file on the hard drive. | ||||
|     :type path: pathlib.Path | ||||
|  | ||||
|     :return: The read string representation of the `level` or `map` | ||||
|     :rtype: List[List[str]] | ||||
|     """ | ||||
|     with path.open('r') as lvl: | ||||
|         level = list(map(lambda x: list(x.strip()), lvl.readlines())) | ||||
|     if len(set([len(line) for line in level])) > 1: | ||||
| @@ -162,29 +275,56 @@ def parse_level(path): | ||||
|     return level | ||||
|  | ||||
|  | ||||
| def one_hot_level(level, wall_char: str = c.WALL): | ||||
| def one_hot_level(level, wall_char: str = Constants.WALL): | ||||
|     """ | ||||
|     Given a string based level representation (list of lists, see function `parse_level`), this function creates a | ||||
|     binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value. | ||||
|     Can be changed to filter for any symbol. | ||||
|  | ||||
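|     :param level: String based level representation (list of lists, see function `parse_level`). | ||||
|     :type level: List[List[str]] | ||||
|     :param wall_char: Symbol whose grid cells are set to `Constants.OCCUPIED_CELL`. | ||||
|     :type wall_char: str | ||||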
|  | ||||
|     :return: Binary numpy array | ||||
|     :rtype: np.typing._array_like.ArrayLike | ||||
|     """ | ||||
|  | ||||
|     grid = np.array(level) | ||||
|     binary_grid = np.zeros(grid.shape, dtype=np.int8) | ||||
|     binary_grid[grid == wall_char] = c.OCCUPIED_CELL | ||||
|     binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL | ||||
|     return binary_grid | ||||
|  | ||||
|  | ||||
| def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]): | ||||
|     """ | ||||
|     Given a slice (2-D array-like object), check whether a position lies within the slice boundaries and | ||||
|     whether the cell at that position is free (holds a falsy value). | ||||
|  | ||||
|     :param slice_to_check_against: The slice to check for accessibility. | ||||
|     :type slice_to_check_against: np.typing._array_like.ArrayLike | ||||
|  | ||||
|     :param position_to_check: Position in slice that should be checked. Can be outside of the slice boundaries. | ||||
|     :type position_to_check: tuple(int, int) | ||||
|  | ||||
|     :return: Whether a position can be moved to. | ||||
|     :rtype: bool | ||||
|     """ | ||||
|     x_pos, y_pos = position_to_check | ||||
|  | ||||
|     # Check if agent collides with grid boundaries | ||||
|     valid = not ( | ||||
|             x_pos < 0 or y_pos < 0 | ||||
|             or x_pos >= slice_to_check_against.shape[0] | ||||
|             or y_pos >= slice_to_check_against.shape[0] | ||||
|             or y_pos >= slice_to_check_against.shape[1] | ||||
|     ) | ||||
|  | ||||
|     # Check for collision with level walls | ||||
|     valid = valid and not slice_to_check_against[x_pos, y_pos] | ||||
|     return c.VALID if valid else c.NOT_VALID | ||||
|     return Constants.VALID if valid else Constants.NOT_VALID | ||||
|  | ||||
|  | ||||
| def asset_str(agent): | ||||
|     """ | ||||
|         FIXME @ romue | ||||
|     """ | ||||
|     # What does this abomination do? | ||||
|     # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): | ||||
|     #     print('error') | ||||
| @@ -192,33 +332,50 @@ def asset_str(agent): | ||||
|         action = step_result['action_name'] | ||||
|         valid = step_result['action_valid'] | ||||
|         col_names = [x.name for x in step_result['collisions']] | ||||
|         if any(c.AGENT in name for name in col_names): | ||||
|         if any(Constants.AGENT in name for name in col_names): | ||||
|             return 'agent_collision', 'blank' | ||||
|         elif not valid or c.LEVEL in col_names or c.AGENT in col_names: | ||||
|             return c.AGENT, 'invalid' | ||||
|         elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names: | ||||
|             return Constants.AGENT, 'invalid' | ||||
|         elif valid and not EnvActions.is_move(action): | ||||
|             return c.AGENT, 'valid' | ||||
|             return Constants.AGENT, 'valid' | ||||
|         elif valid and EnvActions.is_move(action): | ||||
|             return c.AGENT, 'move' | ||||
|             return Constants.AGENT, 'move' | ||||
|         else: | ||||
|             return c.AGENT, 'idle' | ||||
|             return Constants.AGENT, 'idle' | ||||
|     else: | ||||
|         return c.AGENT, 'idle' | ||||
|         return Constants.AGENT, 'idle' | ||||
|  | ||||
|  | ||||
| def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): | ||||
|     """ | ||||
|     Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points. | ||||
|     There are three combinations of settings: | ||||
|         Allow all neighbors:    Distance(a, b) <= sqrt(2) | ||||
|         Allow only manhattan:   Distance(a, b) == 1 | ||||
|         Allow only euclidean:   Distance(a, b) == sqrt(2) | ||||
|  | ||||
|  | ||||
|     :param coordiniates_or_tiles: A set of coordinates. | ||||
|     :type coordiniates_or_tiles: Tiles | ||||
|     :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors | ||||
|     :type allow_euclidean_connections: bool | ||||
|     :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors | ||||
|     :type allow_manhattan_connections: bool | ||||
|  | ||||
|     :return: A graph with nodes that are connected as specified by the parameters. | ||||
|     :rtype: nx.Graph | ||||
|     """ | ||||
|     assert allow_euclidean_connections or allow_manhattan_connections | ||||
|     if hasattr(coordiniates_or_tiles, 'positions'): | ||||
|         coordiniates_or_tiles = coordiniates_or_tiles.positions | ||||
|     possible_connections = itertools.combinations(coordiniates_or_tiles, 2) | ||||
|     graph = nx.Graph() | ||||
|     for a, b in possible_connections: | ||||
|         diff = abs(np.subtract(a, b)) | ||||
|         if not max(diff) > 1: | ||||
|             if allow_manhattan_connections and allow_euclidean_connections: | ||||
|                 graph.add_edge(a, b) | ||||
|             elif not allow_manhattan_connections and allow_euclidean_connections and all(diff): | ||||
|                 graph.add_edge(a, b) | ||||
|             elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff): | ||||
|                 graph.add_edge(a, b) | ||||
|         diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) | ||||
|         if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): | ||||
|             graph.add_edge(a, b) | ||||
|         elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): | ||||
|             graph.add_edge(a, b) | ||||
|         elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: | ||||
|             graph.add_edge(a, b) | ||||
|     return graph | ||||
|   | ||||
| @@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack | ||||
|  | ||||
|  | ||||
| class AgentRenderOptions(object): | ||||
|     """ | ||||
|     Class that specifies the available options for the way agents are represented in the env observation. | ||||
|  | ||||
|     SEPERATE: | ||||
|     Each agent is represented in a separate slice as Constants.OCCUPIED_CELL value (one hot) | ||||
|  | ||||
|     COMBINED: | ||||
|     For all agents, the value of Constants.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE)) | ||||
|  | ||||
|     LEVEL: | ||||
|     The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall) | ||||
|  | ||||
|     NOT: | ||||
|     The position of individual agents can not be read from the observation. | ||||
|     """ | ||||
|  | ||||
|     SEPERATE = 'seperate' | ||||
|     COMBINED = 'combined' | ||||
|     LEVEL = 'lvl' | ||||
| @@ -11,24 +27,61 @@ class AgentRenderOptions(object): | ||||
|  | ||||
|  | ||||
| class MovementProperties(NamedTuple): | ||||
|     """ | ||||
|     Property holder; for setting multiple related parameters through a single parameter. Comes with default values. | ||||
|     """ | ||||
|  | ||||
|     """Allow the manhattan style movement on a grid (move to cells that are connected by square edges).""" | ||||
|     allow_square_movement: bool = True | ||||
|  | ||||
|     """Allow diagonal movement on the grid (move to cells that are connected by square corners).""" | ||||
|     allow_diagonal_movement: bool = False | ||||
|  | ||||
|     """Allow the agent to just do nothing; not move (NO-OP).""" | ||||
|     allow_no_op: bool = False | ||||
|  | ||||
|  | ||||
| class ObservationProperties(NamedTuple): | ||||
|     # Todo: Add Description | ||||
|     """ | ||||
|     Property holder; for setting multiple related parameters through a single parameter. Comes with default values. | ||||
|     """ | ||||
|  | ||||
|     """How to represent agents in the observation space. This may also alter the obs-shape.""" | ||||
|     render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE | ||||
|  | ||||
|     """Observations are built per agent; whether the current agent should be represented in its own observation.""" | ||||
|     omit_agent_self: bool = True | ||||
|  | ||||
|     """There might be cases where you want to modify the agent's obs-space, so that it can be used with additional obs. | ||||
|        The additional slice can be filled with any number.""" | ||||
|     additional_agent_placeholder: Union[None, str, int] = None | ||||
|  | ||||
|     """Whether to cast shadows (make floor tiles and items hidden).""" | ||||
|     cast_shadows: bool = True | ||||
|  | ||||
|     """Frame stacking is a method to give some temporal information to the agents. | ||||
|     This parameter controls how many "old frames" are stacked.""" | ||||
|     frames_to_stack: int = 0 | ||||
|     pomdp_r: int = 0 | ||||
|  | ||||
|     """Specifies the radius (_r) of the agent's field of view. Please note that the agent's own grid cell is not | ||||
|         counted. This means that the resulting field-of-view diameter is `pomdp_r * 2 + 1`. | ||||
|         A 'pomdp_r' of 0 always returns the full env, i.e. no partial observability.""" | ||||
|     pomdp_r: int = 2 | ||||
|  | ||||
|     """Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be | ||||
|     operated from their surrounding area, so the agent can more easily get a notion of where to choose the door option. | ||||
|     However, this is not necessary at all. | ||||
|         """ | ||||
|     indicate_door_area: bool = False | ||||
|  | ||||
|     """Whether to add the agent's normalized global position as float values (2,1) to a separate information slice. | ||||
|         More optional information is to come. | ||||
|         """ | ||||
|     show_global_position_info: bool = False | ||||
|  | ||||
|  | ||||
| class MarlFrameStack(gym.ObservationWrapper): | ||||
|     """todo @romue404""" | ||||
|     def __init__(self, env): | ||||
|         super().__init__(env) | ||||
|  | ||||
|   | ||||
| @@ -215,7 +215,7 @@ if __name__ == '__main__': | ||||
|                                 clean_amount=0.34, | ||||
|                                 max_spawn_amount=0.1, max_global_amount=20, | ||||
|                                 max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, | ||||
|                                 dirt_smear_amount=0.0, agent_can_interact=True) | ||||
|                                 dirt_smear_amount=0.0) | ||||
|     item_props = ItemProperties(n_items=10, | ||||
|                                 spawn_frequency=30, n_drop_off_locations=2, | ||||
|                                 max_agent_inventory_capacity=15) | ||||
| @@ -349,6 +349,7 @@ if __name__ == '__main__': | ||||
|                         # Env Init & Model kwargs definition | ||||
|                         if model_cls.__name__ in ["PPO", "A2C"]: | ||||
|                             # env_factory = env_class(**env_kwargs) | ||||
|  | ||||
|                             env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) | ||||
|                                                          for _ in range(6)], start_method="spawn") | ||||
|                             model_kwargs = policy_model_kwargs() | ||||
|   | ||||
| @@ -213,7 +213,8 @@ if __name__ == '__main__': | ||||
|                     env_factory.save_params(param_path) | ||||
|  | ||||
|                 # EnvMonitor Init | ||||
|                 callbacks = [EnvMonitor(env_factory)] | ||||
|                 env_monitor = EnvMonitor(env_factory) | ||||
|                 callbacks = [env_monitor] | ||||
|  | ||||
|                 # Model Init | ||||
|                 model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs, | ||||
| @@ -233,7 +234,7 @@ if __name__ == '__main__': | ||||
|                 model.save(save_path) | ||||
|  | ||||
|                 # Monitor Save | ||||
|                 callbacks[0].save_run(combination_path / 'monitor.pick', | ||||
|                 env_monitor.save_run(combination_path / 'monitor.pick', | ||||
|                                       auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys) | ||||
|  | ||||
|                 # Better be safe than sorry: Clean up! | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium