added documentation for env groups

This commit is contained in:
Chanumask
2024-01-08 11:13:46 +01:00
parent 70bbdd256f
commit 0f6f34f83e
6 changed files with 296 additions and 11 deletions

View File

@@ -1,6 +1,6 @@
# About EDYS # About EDYS
### Tackling emergent dysfunctions (EDYs) in cooperation with Fraunhofer-IKS. ## Tackling emergent dysfunctions (EDYs) in cooperation with Fraunhofer-IKS.
Collaborating with Fraunhofer-IKS, this project is dedicated to investigating Emergent Dysfunctions (EDYs) Collaborating with Fraunhofer-IKS, this project is dedicated to investigating Emergent Dysfunctions (EDYs)
within multi-agent environments. within multi-agent environments.
@@ -40,7 +40,7 @@ Refer to [quickstart](_quickstart) for specific scenarios.
## Usage ## Usage
The majority of environment objects, including entities, rules, and assets, can be loaded automatically. The majority of environment objects, including entities, rules, and assets, can be loaded automatically.
Simply specify the requirements of your environment in a [*yaml*-configfile](marl_factory_grid/configs/default_config.yaml). Simply specify the requirements of your environment in a [*yaml*-config file](marl_factory_grid/configs/default_config.yaml).
If you only plan on using the environment without making any modifications, use ``quickstart_use``. If you only plan on using the environment without making any modifications, use ``quickstart_use``.
This creates a default config-file and another one that lists all possible options of the environment. This creates a default config-file and another one that lists all possible options of the environment.

View File

@@ -37,10 +37,10 @@ class Agents(Collection):
@property @property
def action_space(self): def action_space(self):
""" """
TODO The action space defines the set of all possible actions that an agent can take in the environment.
:return: Action space
:return: :rtype: gym.Space
""" """
from gymnasium import spaces from gymnasium import spaces
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self]) space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
@@ -49,10 +49,10 @@ class Agents(Collection):
@property @property
def named_action_space(self) -> dict[str, dict[str, list[int]]]: def named_action_space(self) -> dict[str, dict[str, list[int]]]:
""" """
TODO Returns the named action space for agents.
:return: Named action space
:return: :rtype: dict[str, dict[str, list[int]]]
""" """
named_space = dict() named_space = dict()
for agent in self: for agent in self:

View File

@@ -13,31 +13,65 @@ class Collection(Objects):
@property @property
def var_is_blocking_light(self): def var_is_blocking_light(self):
"""
Indicates whether the collection blocks light.
:return: Always False for a collection.
"""
return False return False
@property @property
def var_is_blocking_pos(self): def var_is_blocking_pos(self):
"""
Indicates whether the collection blocks positions.
:return: Always False for a collection.
"""
return False return False
@property @property
def var_can_collide(self): def var_can_collide(self):
"""
Indicates whether the collection can collide.
:return: Always False for a collection.
"""
return False return False
@property @property
def var_can_move(self): def var_can_move(self):
"""
Indicates whether the collection can move.
:return: Always False for a collection.
"""
return False return False
@property @property
def var_has_position(self): def var_has_position(self):
"""
Indicates whether the collection has positions.
:return: Always True for a collection.
"""
return True return True
@property @property
def encodings(self): def encodings(self):
"""
Returns a list of encodings for all entities in the collection.
:return: List of encodings.
"""
return [x.encoding for x in self] return [x.encoding for x in self]
@property @property
def spawn_rule(self): def spawn_rule(self):
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g.""" """
Prevents SpawnRule creation if Objects are spawned by the map, doors, etc.
:return: The spawn rule or None.
"""
if self.symbol: if self.symbol:
return None return None
elif self._spawnrule: elif self._spawnrule:
@@ -48,6 +82,17 @@ class Collection(Objects):
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False, def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
spawnrule: Union[None, Dict[str, dict]] = None, spawnrule: Union[None, Dict[str, dict]] = None,
**kwargs): **kwargs):
"""
Initializes the Collection.
:param size: Size of the collection.
:type size: int
:param coords_or_quantity: Coordinates or quantity for spawning entities.
:param ignore_blocking: Ignore blocking when spawning entities.
:type ignore_blocking: bool
:param spawnrule: Spawn rule for the collection. Default: None
:type spawnrule: Union[None, Dict[str, dict]]
"""
super(Collection, self).__init__(*args, **kwargs) super(Collection, self).__init__(*args, **kwargs)
self._coords_or_quantity = coords_or_quantity self._coords_or_quantity = coords_or_quantity
self.size = size self.size = size
@@ -55,6 +100,17 @@ class Collection(Objects):
self._ignore_blocking = ignore_blocking self._ignore_blocking = ignore_blocking
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs): def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
"""
Triggers the spawning of entities in the collection.
:param state: The game state.
:type state: marl_factory_grid.utils.states.GameState
:param entity_args: Additional arguments for entity creation.
:param coords_or_quantity: Coordinates or quantity for spawning entities.
:param ignore_blocking: Ignore blocking when spawning entities.
:param entity_kwargs: Additional keyword arguments for entity creation.
:return: Result of the spawn operation.
"""
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position: if self.var_has_position:
if self.var_has_position and isinstance(coords_or_quantity, int): if self.var_has_position and isinstance(coords_or_quantity, int):
@@ -74,6 +130,14 @@ class Collection(Objects):
raise ValueError(f'{self._entity.__name__} has no position!') raise ValueError(f'{self._entity.__name__} has no position!')
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs): def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
"""
Spawns entities in the collection.
:param coords_or_quantity: Coordinates or quantity for spawning entities.
:param entity_args: Additional arguments for entity creation.
:param entity_kwargs: Additional keyword arguments for entity creation.
:return: Validity of the spawn operation.
"""
if self.var_has_position: if self.var_has_position:
if isinstance(coords_or_quantity, int): if isinstance(coords_or_quantity, int):
raise ValueError(f'{self._entity.__name__} should have a position!') raise ValueError(f'{self._entity.__name__} should have a position!')
@@ -87,6 +151,11 @@ class Collection(Objects):
return c.VALID return c.VALID
def despawn(self, items: List[Object]): def despawn(self, items: List[Object]):
"""
Despawns entities from the collection.
:param items: List of entities to despawn.
"""
items = [items] if isinstance(items, Object) else items items = [items] if isinstance(items, Object) else items
for item in items: for item in items:
del self[item] del self[item]
@@ -97,9 +166,19 @@ class Collection(Objects):
return self return self
def delete_env_object(self, env_object): def delete_env_object(self, env_object):
"""
Deletes an environmental object from the collection.
:param env_object: The environmental object to delete.
"""
del self[env_object.name] del self[env_object.name]
def delete_env_object_by_name(self, name): def delete_env_object_by_name(self, name):
"""
Deletes an environmental object from the collection by name.
:param name: The name of the environmental object to delete.
"""
del self[name] del self[name]
@property @property
@@ -126,6 +205,13 @@ class Collection(Objects):
@classmethod @classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
"""
Creates a collection of entities from specified coordinates.
:param positions: List of coordinates for entity positions.
:param args: Additional positional arguments.
:return: The created collection.
"""
collection = cls(*args, **kwargs) collection = cls(*args, **kwargs)
collection.add_items( collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
@@ -141,6 +227,12 @@ class Collection(Objects):
super().__delitem__(name) super().__delitem__(name)
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):
"""
Retrieves an entity from the collection based on its position.
:param pos: The position tuple.
:return: The entity at the specified position or None if not found.
"""
pos = tuple(pos) pos = tuple(pos)
try: try:
return self.pos_dict[pos] return self.pos_dict[pos]
@@ -151,6 +243,11 @@ class Collection(Objects):
@property @property
def positions(self): def positions(self):
"""
Returns a list of positions for all entities in the collection.
:return: List of positions.
"""
return [e.pos for e in self] return [e.pos for e in self]
def notify_del_entity(self, entity: Entity): def notify_del_entity(self, entity: Entity):

View File

@@ -11,12 +11,30 @@ class Entities(Objects):
_entity = Objects _entity = Objects
def neighboring_positions(self, pos): def neighboring_positions(self, pos):
"""
Get all 8 neighboring positions of a given position.
:param pos: The reference position.
:return: List of neighboring positions.
"""
return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions] return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
def neighboring_4_positions(self, pos): def neighboring_4_positions(self, pos):
"""
Get neighboring 4 positions of a given position. (North, East, South, West)
:param pos: Reference position.
:return: List of neighboring positions.
"""
return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions] return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions]
def get_entities_near_pos(self, pos): def get_entities_near_pos(self, pos):
"""
Get entities near a given position.
:param pos: The reference position.
:return: List of entities near the position.
"""
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
def render(self): def render(self):
@@ -28,10 +46,18 @@ class Entities(Objects):
@property @property
def floorlist(self): def floorlist(self):
"""
Shuffle and return the list of floor positions.
:return: Shuffled list of floor positions.
"""
shuffle(self._floor_positions) shuffle(self._floor_positions)
return [x for x in self._floor_positions] return [x for x in self._floor_positions]
def __init__(self, floor_positions): def __init__(self, floor_positions):
"""
:param floor_positions: list of all positions that are not blocked by a wall.
"""
self._floor_positions = floor_positions self._floor_positions = floor_positions
self.pos_dict = None self.pos_dict = None
super().__init__() super().__init__()
@@ -40,28 +66,54 @@ class Entities(Objects):
return f'{self.__class__.__name__}{[x for x in self]}' return f'{self.__class__.__name__}{[x for x in self]}'
def guests_that_can_collide(self, pos): def guests_that_can_collide(self, pos):
"""
Get entities at a position that can collide.
:param pos: The reference position.
:return: List of entities at the position that can collide.
"""
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide] return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property @property
def empty_positions(self): def empty_positions(self):
"""
Get shuffled list of empty positions.
:return: Shuffled list of empty positions.
"""
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
shuffle(empty_positions) shuffle(empty_positions)
return empty_positions return empty_positions
@property @property
def occupied_positions(self): # positions that are not empty def occupied_positions(self):
"""
Get shuffled list of occupied positions.
:return: Shuffled list of occupied positions.
"""
empty_positions = [key for key in self.floorlist if self.pos_dict[key]] empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions) shuffle(empty_positions)
return empty_positions return empty_positions
@property @property
def blocked_positions(self): def blocked_positions(self):
"""
Get shuffled list of blocked positions.
:return: Shuffled list of blocked positions.
"""
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
shuffle(blocked_positions) shuffle(blocked_positions)
return blocked_positions return blocked_positions
@property @property
def free_positions_generator(self): def free_positions_generator(self):
"""
Get a generator for free positions.
:return: Generator for free positions.
"""
generator = ( generator = (
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
for x in self.pos_dict[key]) for x in self.pos_dict[key])
@@ -70,9 +122,19 @@ class Entities(Objects):
@property @property
def free_positions_list(self): def free_positions_list(self):
"""
Get a list of free positions.
:return: List of free positions.
"""
return [x for x in self.free_positions_generator] return [x for x in self.free_positions_generator]
def iter_entities(self): def iter_entities(self):
"""
Get an iterator over all entities in the collection.
:return: Iterator over entities.
"""
return iter((x for sublist in self.values() for x in sublist)) return iter((x for sublist in self.values() for x in sublist))
def add_items(self, items: Dict): def add_items(self, items: Dict):
@@ -105,13 +167,30 @@ class Entities(Objects):
print('OhOh (debug me)') print('OhOh (debug me)')
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):
"""
Get entities at a specific position.
:param pos: The reference position.
:return: List of entities at the position.
"""
return self.pos_dict[pos] return self.pos_dict[pos]
@property @property
def positions(self): def positions(self):
"""
Get a list of all positions in the collection.
:return: List of positions.
"""
return [k for k, v in self.pos_dict.items() for _ in v] return [k for k, v in self.pos_dict.items() for _ in v]
def is_occupied(self, pos): def is_occupied(self, pos):
"""
Check if a position is occupied.
:param pos: The reference position.
:return: True if the position is occupied, False otherwise.
"""
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1 return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
def reset(self): def reset(self):

View File

@@ -1,29 +1,58 @@
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
"""
Mixins are a way to modularly extend the functionality of classes in object-oriented programming without using
inheritance in the traditional sense. They provide a means to include a set of methods and properties in a class that
can be reused across different class hierarchies.
"""
# noinspection PyUnresolvedReferences,PyTypeChecker # noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin: class IsBoundMixin:
"""
This mixin is designed to be used in classes that represent objects which can be bound to another entity.
"""
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})' return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
def bind(self, entity): def bind(self, entity):
"""
Binds the current object to another entity.
:param entity: the entity to be bound
"""
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self._bound_entity = entity self._bound_entity = entity
return c.VALID return c.VALID
def belongs_to_entity(self, entity): def belongs_to_entity(self, entity):
"""
Checks if the given entity is the bound entity.
:return: True if the given entity is the bound entity, false otherwise.
"""
return self._bound_entity == entity return self._bound_entity == entity
# noinspection PyUnresolvedReferences,PyTypeChecker # noinspection PyUnresolvedReferences,PyTypeChecker
class HasBoundMixin: class HasBoundMixin:
"""
This mixin is intended for classes that contain a collection of objects and need functionality to interact with
those objects.
"""
@property @property
def obs_pairs(self): def obs_pairs(self):
"""
Returns a list of pairs containing the names and corresponding objects within the collection.
"""
return [(x.name, x) for x in self] return [(x.name, x) for x in self]
def by_entity(self, entity): def by_entity(self, entity):
"""
Retrieves an object from the collection based on its belonging to a specific entity.
"""
try: try:
return next((x for x in self if x.belongs_to_entity(entity))) return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError): except (StopIteration, AttributeError):

View File

@@ -13,22 +13,37 @@ class Objects:
@property @property
def var_can_be_bound(self): def var_can_be_bound(self):
"""
Property indicating whether objects in the collection can be bound to another entity.
"""
return False return False
@property @property
def observers(self): def observers(self):
"""
Property returning a set of observers associated with the collection.
"""
return self._observers return self._observers
@property @property
def obs_tag(self): def obs_tag(self):
"""
Property providing a tag for observation purposes.
"""
return self.__class__.__name__ return self.__class__.__name__
@staticmethod @staticmethod
def render(): def render():
"""
Static method returning an empty list. Override this method in derived classes for rendering functionality.
"""
return [] return []
@property @property
def obs_pairs(self): def obs_pairs(self):
"""
Property returning a list of pairs containing the names and corresponding objects within the collection.
"""
pair_list = [(self.name, self)] pair_list = [(self.name, self)]
pair_list.extend([(a.name, a) for a in self]) pair_list.extend([(a.name, a) for a in self])
return pair_list return pair_list
@@ -48,12 +63,26 @@ class Objects:
self.pos_dict = defaultdict(list) self.pos_dict = defaultdict(list)
def __len__(self): def __len__(self):
"""
Returns the number of objects in the collection.
"""
return len(self._data) return len(self._data)
def __iter__(self) -> Iterator[Union[Object, None]]: def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values()) return iter(self.values())
def add_item(self, item: _entity): def add_item(self, item: _entity):
"""
Adds an item to the collection.
:param item: The object to add to the collection.
:returns: The updated collection.
Raises:
AssertionError: If the item is not of the correct type or already exists in the collection.
"""
assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,' assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
assert isinstance(item, self._entity), assert_str assert isinstance(item, self._entity), assert_str
assert self._data[item.name] is None, f'{item.name} allready exists!!!' assert self._data[item.name] is None, f'{item.name} allready exists!!!'
@@ -66,6 +95,9 @@ class Objects:
return self return self
def remove_item(self, item: _entity): def remove_item(self, item: _entity):
"""
Removes an item from the collection.
"""
for observer in item.observers: for observer in item.observers:
observer.notify_del_entity(item) observer.notify_del_entity(item)
# noinspection PyTypeChecker # noinspection PyTypeChecker
@@ -77,6 +109,9 @@ class Objects:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
def del_observer(self, observer): def del_observer(self, observer):
"""
Removes an observer from the collection and its entities.
"""
self.observers.remove(observer) self.observers.remove(observer)
for entity in self: for entity in self:
if observer in entity.observers: if observer in entity.observers:
@@ -84,31 +119,56 @@ class Objects:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
def add_observer(self, observer): def add_observer(self, observer):
"""
Adds an observer to the collection and its entities.
"""
self.observers.add(observer) self.observers.add(observer)
for entity in self: for entity in self:
entity.add_observer(observer) entity.add_observer(observer)
def add_items(self, items: List[_entity]): def add_items(self, items: List[_entity]):
"""
Adds a list of items to the collection.
:param items: List of items to add.
:type items: List[_entity]
:returns: The updated collection.
"""
for item in items: for item in items:
self.add_item(item) self.add_item(item)
return self return self
def keys(self): def keys(self):
"""
Returns the keys (names) of the objects in the collection.
"""
return self._data.keys() return self._data.keys()
def values(self): def values(self):
"""
Returns the values (objects) in the collection.
"""
return self._data.values() return self._data.values()
def items(self): def items(self):
"""
Returns the items (name-object pairs) in the collection.
"""
return self._data.items() return self._data.items()
def _get_index(self, item): def _get_index(self, item):
"""
Gets the index of an item in the collection.
"""
try: try:
return next(i for i, v in enumerate(self._data.values()) if v == item) return next(i for i, v in enumerate(self._data.values()) if v == item)
except StopIteration: except StopIteration:
return None return None
def by_name(self, name): def by_name(self, name):
"""
Gets an object from the collection by its name.
"""
return next(x for x in self if x.name == name) return next(x for x in self if x.name == name)
def __getitem__(self, item): def __getitem__(self, item):
@@ -131,6 +191,9 @@ class Objects:
return f'{self.__class__.__name__}[{len(self)}]' return f'{self.__class__.__name__}[{len(self)}]'
def notify_del_entity(self, entity: Object): def notify_del_entity(self, entity: Object):
"""
Notifies the collection that an entity has been deleted.
"""
try: try:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity) self.pos_dict[entity.pos].remove(entity)
@@ -138,6 +201,9 @@ class Objects:
pass pass
def notify_add_entity(self, entity: Object): def notify_add_entity(self, entity: Object):
"""
Notifies the collection that an entity has been added.
"""
try: try:
if self not in entity.observers: if self not in entity.observers:
entity.add_observer(self) entity.add_observer(self)
@@ -148,24 +214,38 @@ class Objects:
pass pass
def summarize_states(self): def summarize_states(self):
"""
Summarizes the states of all entities in the collection.
:returns: A list of dictionaries representing the summarized states of the entities.
:rtype: List[dict]
"""
# FIXME PROTOBUFF # FIXME PROTOBUFF
# return [e.summarize_state() for e in self] # return [e.summarize_state() for e in self]
return [e.summarize_state() for e in self] return [e.summarize_state() for e in self]
def by_entity(self, entity): def by_entity(self, entity):
"""
Gets an entity from the collection that belongs to a specified entity.
"""
try: try:
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError): except (StopIteration, AttributeError):
return None return None
def idx_by_entity(self, entity): def idx_by_entity(self, entity):
"""
Gets the index of an entity in the collection.
"""
try: try:
return h.get_first_index(self, filter_by=lambda x: x == entity) return h.get_first_index(self, filter_by=lambda x: x == entity)
except (StopIteration, AttributeError): except (StopIteration, AttributeError):
return None return None
def reset(self): def reset(self):
"""
Resets the collection by clearing data and observers.
"""
self._data = defaultdict(lambda: None) self._data = defaultdict(lambda: None)
self._observers = set(self) self._observers = set(self)
self.pos_dict = defaultdict(list) self.pos_dict = defaultdict(list)