mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
added documentation for env groups
This commit is contained in:
@ -37,10 +37,10 @@ class Agents(Collection):
|
||||
@property
|
||||
def action_space(self):
|
||||
"""
|
||||
TODO
|
||||
The action space defines the set of all possible actions that an agent can take in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Action space
|
||||
:rtype: gym.Space
|
||||
"""
|
||||
from gymnasium import spaces
|
||||
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
|
||||
@ -49,10 +49,10 @@ class Agents(Collection):
|
||||
@property
|
||||
def named_action_space(self) -> dict[str, dict[str, list[int]]]:
|
||||
"""
|
||||
TODO
|
||||
Returns the named action space for agents.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Named action space
|
||||
:rtype: dict[str, dict[str, list[int]]]
|
||||
"""
|
||||
named_space = dict()
|
||||
for agent in self:
|
||||
|
@ -13,31 +13,65 @@ class Collection(Objects):
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
"""
|
||||
Indicates whether the collection blocks light.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_pos(self):
|
||||
"""
|
||||
Indicates whether the collection blocks positions.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
"""
|
||||
Indicates whether the collection can collide.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
"""
|
||||
Indicates whether the collection can move.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
"""
|
||||
Indicates whether the collection has positions.
|
||||
|
||||
:return: Always True for a collection.
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
"""
|
||||
Returns a list of encodings for all entities in the collection.
|
||||
|
||||
:return: List of encodings.
|
||||
"""
|
||||
return [x.encoding for x in self]
|
||||
|
||||
@property
|
||||
def spawn_rule(self):
|
||||
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
|
||||
"""
|
||||
Prevents SpawnRule creation if Objects are spawned by the map, doors, etc.
|
||||
|
||||
:return: The spawn rule or None.
|
||||
"""
|
||||
if self.symbol:
|
||||
return None
|
||||
elif self._spawnrule:
|
||||
@ -48,6 +82,17 @@ class Collection(Objects):
|
||||
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
|
||||
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Initializes the Collection.
|
||||
|
||||
:param size: Size of the collection.
|
||||
:type size: int
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param ignore_blocking: Ignore blocking when spawning entities.
|
||||
:type ignore_blocking: bool
|
||||
:param spawnrule: Spawn rule for the collection. Default: None
|
||||
:type spawnrule: Union[None, Dict[str, dict]]
|
||||
"""
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self._coords_or_quantity = coords_or_quantity
|
||||
self.size = size
|
||||
@ -55,6 +100,17 @@ class Collection(Objects):
|
||||
self._ignore_blocking = ignore_blocking
|
||||
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||
"""
|
||||
Triggers the spawning of entities in the collection.
|
||||
|
||||
:param state: The game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:param entity_args: Additional arguments for entity creation.
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param ignore_blocking: Ignore blocking when spawning entities.
|
||||
:param entity_kwargs: Additional keyword arguments for entity creation.
|
||||
:return: Result of the spawn operation.
|
||||
"""
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
if self.var_has_position:
|
||||
if self.var_has_position and isinstance(coords_or_quantity, int):
|
||||
@ -74,6 +130,14 @@ class Collection(Objects):
|
||||
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
|
||||
"""
|
||||
Spawns entities in the collection.
|
||||
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param entity_args: Additional arguments for entity creation.
|
||||
:param entity_kwargs: Additional keyword arguments for entity creation.
|
||||
:return: Validity of the spawn operation.
|
||||
"""
|
||||
if self.var_has_position:
|
||||
if isinstance(coords_or_quantity, int):
|
||||
raise ValueError(f'{self._entity.__name__} should have a position!')
|
||||
@ -87,6 +151,11 @@ class Collection(Objects):
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[Object]):
|
||||
"""
|
||||
Despawns entities from the collection.
|
||||
|
||||
:param items: List of entities to despawn.
|
||||
"""
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
@ -97,9 +166,19 @@ class Collection(Objects):
|
||||
return self
|
||||
|
||||
def delete_env_object(self, env_object):
|
||||
"""
|
||||
Deletes an environmental object from the collection.
|
||||
|
||||
:param env_object: The environmental object to delete.
|
||||
"""
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
"""
|
||||
Deletes an environmental object from the collection by name.
|
||||
|
||||
:param name: The name of the environmental object to delete.
|
||||
"""
|
||||
del self[name]
|
||||
|
||||
@property
|
||||
@ -126,6 +205,13 @@ class Collection(Objects):
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||
"""
|
||||
Creates a collection of entities from specified coordinates.
|
||||
|
||||
:param positions: List of coordinates for entity positions.
|
||||
:param args: Additional positional arguments.
|
||||
:return: The created collection.
|
||||
"""
|
||||
collection = cls(*args, **kwargs)
|
||||
collection.add_items(
|
||||
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
|
||||
@ -141,6 +227,12 @@ class Collection(Objects):
|
||||
super().__delitem__(name)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
"""
|
||||
Retrieves an entity from the collection based on its position.
|
||||
|
||||
:param pos: The position tuple.
|
||||
:return: The entity at the specified position or None if not found.
|
||||
"""
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return self.pos_dict[pos]
|
||||
@ -151,6 +243,11 @@ class Collection(Objects):
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
"""
|
||||
Returns a list of positions for all entities in the collection.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
return [e.pos for e in self]
|
||||
|
||||
def notify_del_entity(self, entity: Entity):
|
||||
|
@ -11,12 +11,30 @@ class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
def neighboring_positions(self, pos):
|
||||
"""
|
||||
Get all 8 neighboring positions of a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
|
||||
|
||||
def neighboring_4_positions(self, pos):
|
||||
"""
|
||||
Get neighboring 4 positions of a given position. (North, East, South, West)
|
||||
|
||||
:param pos: Reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions]
|
||||
|
||||
def get_entities_near_pos(self, pos):
|
||||
"""
|
||||
Get entities near a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities near the position.
|
||||
"""
|
||||
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
@ -28,10 +46,18 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def floorlist(self):
|
||||
"""
|
||||
Shuffle and return the list of floor positions.
|
||||
|
||||
:return: Shuffled list of floor positions.
|
||||
"""
|
||||
shuffle(self._floor_positions)
|
||||
return [x for x in self._floor_positions]
|
||||
|
||||
def __init__(self, floor_positions):
|
||||
"""
|
||||
:param floor_positions: list of all positions that are not blocked by a wall.
|
||||
"""
|
||||
self._floor_positions = floor_positions
|
||||
self.pos_dict = None
|
||||
super().__init__()
|
||||
@ -40,28 +66,54 @@ class Entities(Objects):
|
||||
return f'{self.__class__.__name__}{[x for x in self]}'
|
||||
|
||||
def guests_that_can_collide(self, pos):
|
||||
"""
|
||||
Get entities at a position that can collide.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position that can collide.
|
||||
"""
|
||||
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def empty_positions(self):
|
||||
"""
|
||||
Get shuffled list of empty positions.
|
||||
|
||||
:return: Shuffled list of empty positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def occupied_positions(self): # positions that are not empty
|
||||
def occupied_positions(self):
|
||||
"""
|
||||
Get shuffled list of occupied positions.
|
||||
|
||||
:return: Shuffled list of occupied positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def blocked_positions(self):
|
||||
"""
|
||||
Get shuffled list of blocked positions.
|
||||
|
||||
:return: Shuffled list of blocked positions.
|
||||
"""
|
||||
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||
shuffle(blocked_positions)
|
||||
return blocked_positions
|
||||
|
||||
@property
|
||||
def free_positions_generator(self):
|
||||
"""
|
||||
Get a generator for free positions.
|
||||
|
||||
:return: Generator for free positions.
|
||||
"""
|
||||
generator = (
|
||||
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
|
||||
for x in self.pos_dict[key])
|
||||
@ -70,9 +122,19 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def free_positions_list(self):
|
||||
"""
|
||||
Get a list of free positions.
|
||||
|
||||
:return: List of free positions.
|
||||
"""
|
||||
return [x for x in self.free_positions_generator]
|
||||
|
||||
def iter_entities(self):
|
||||
"""
|
||||
Get an iterator over all entities in the collection.
|
||||
|
||||
:return: Iterator over entities.
|
||||
"""
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def add_items(self, items: Dict):
|
||||
@ -105,13 +167,30 @@ class Entities(Objects):
|
||||
print('OhOh (debug me)')
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
"""
|
||||
Get entities at a specific position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position.
|
||||
"""
|
||||
return self.pos_dict[pos]
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
"""
|
||||
Get a list of all positions in the collection.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
return [k for k, v in self.pos_dict.items() for _ in v]
|
||||
|
||||
def is_occupied(self, pos):
|
||||
"""
|
||||
Check if a position is occupied.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: True if the position is occupied, False otherwise.
|
||||
"""
|
||||
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||
|
||||
def reset(self):
|
||||
|
@ -1,29 +1,58 @@
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
"""
|
||||
Mixins are a way to modularly extend the functionality of classes in object-oriented programming without using
|
||||
inheritance in the traditional sense. They provide a means to include a set of methods and properties in a class that
|
||||
can be reused across different class hierarchies.
|
||||
"""
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class IsBoundMixin:
|
||||
"""
|
||||
This mixin is designed to be used in classes that represent objects which can be bound to another entity.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
|
||||
|
||||
def bind(self, entity):
|
||||
"""
|
||||
Binds the current object to another entity.
|
||||
|
||||
:param entity: the entity to be bound
|
||||
"""
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._bound_entity = entity
|
||||
return c.VALID
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
"""
|
||||
Checks if the given entity is the bound entity.
|
||||
|
||||
:return: True if the given entity is the bound entity, false otherwise.
|
||||
"""
|
||||
return self._bound_entity == entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class HasBoundMixin:
|
||||
"""
|
||||
This mixin is intended for classes that contain a collection of objects and need functionality to interact with
|
||||
those objects.
|
||||
"""
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
"""
|
||||
Returns a list of pairs containing the names and corresponding objects within the collection.
|
||||
"""
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
"""
|
||||
Retrieves an object from the collection based on its belonging to a specific entity.
|
||||
"""
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
|
@ -13,22 +13,37 @@ class Objects:
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
"""
|
||||
Property indicating whether objects in the collection can be bound to another entity.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def observers(self):
|
||||
"""
|
||||
Property returning a set of observers associated with the collection.
|
||||
"""
|
||||
return self._observers
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
"""
|
||||
Property providing a tag for observation purposes.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def render():
|
||||
"""
|
||||
Static method returning an empty list. Override this method in derived classes for rendering functionality.
|
||||
"""
|
||||
return []
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
"""
|
||||
Property returning a list of pairs containing the names and corresponding objects within the collection.
|
||||
"""
|
||||
pair_list = [(self.name, self)]
|
||||
pair_list.extend([(a.name, a) for a in self])
|
||||
return pair_list
|
||||
@ -48,12 +63,26 @@ class Objects:
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the number of objects in the collection.
|
||||
"""
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self) -> Iterator[Union[Object, None]]:
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
"""
|
||||
Adds an item to the collection.
|
||||
|
||||
|
||||
:param item: The object to add to the collection.
|
||||
|
||||
:returns: The updated collection.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the item is not of the correct type or already exists in the collection.
|
||||
"""
|
||||
assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
|
||||
assert isinstance(item, self._entity), assert_str
|
||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||
@ -66,6 +95,9 @@ class Objects:
|
||||
return self
|
||||
|
||||
def remove_item(self, item: _entity):
|
||||
"""
|
||||
Removes an item from the collection.
|
||||
"""
|
||||
for observer in item.observers:
|
||||
observer.notify_del_entity(item)
|
||||
# noinspection PyTypeChecker
|
||||
@ -77,6 +109,9 @@ class Objects:
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def del_observer(self, observer):
|
||||
"""
|
||||
Removes an observer from the collection and its entities.
|
||||
"""
|
||||
self.observers.remove(observer)
|
||||
for entity in self:
|
||||
if observer in entity.observers:
|
||||
@ -84,31 +119,56 @@ class Objects:
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def add_observer(self, observer):
|
||||
"""
|
||||
Adds an observer to the collection and its entities.
|
||||
"""
|
||||
self.observers.add(observer)
|
||||
for entity in self:
|
||||
entity.add_observer(observer)
|
||||
|
||||
def add_items(self, items: List[_entity]):
|
||||
"""
|
||||
Adds a list of items to the collection.
|
||||
|
||||
:param items: List of items to add.
|
||||
:type items: List[_entity]
|
||||
:returns: The updated collection.
|
||||
"""
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
"""
|
||||
Returns the keys (names) of the objects in the collection.
|
||||
"""
|
||||
return self._data.keys()
|
||||
|
||||
def values(self):
|
||||
"""
|
||||
Returns the values (objects) in the collection.
|
||||
"""
|
||||
return self._data.values()
|
||||
|
||||
def items(self):
|
||||
"""
|
||||
Returns the items (name-object pairs) in the collection.
|
||||
"""
|
||||
return self._data.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
"""
|
||||
Gets the index of an item in the collection.
|
||||
"""
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._data.values()) if v == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_name(self, name):
|
||||
"""
|
||||
Gets an object from the collection by its name.
|
||||
"""
|
||||
return next(x for x in self if x.name == name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
@ -131,6 +191,9 @@ class Objects:
|
||||
return f'{self.__class__.__name__}[{len(self)}]'
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
"""
|
||||
Notifies the collection that an entity has been deleted.
|
||||
"""
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
@ -138,6 +201,9 @@ class Objects:
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
"""
|
||||
Notifies the collection that an entity has been added.
|
||||
"""
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
@ -148,24 +214,38 @@ class Objects:
|
||||
pass
|
||||
|
||||
def summarize_states(self):
|
||||
"""
|
||||
Summarizes the states of all entities in the collection.
|
||||
|
||||
:returns: A list of dictionaries representing the summarized states of the entities.
|
||||
:rtype: List[dict]
|
||||
"""
|
||||
# FIXME PROTOBUFF
|
||||
# return [e.summarize_state() for e in self]
|
||||
return [e.summarize_state() for e in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
"""
|
||||
Gets an entity from the collection that belongs to a specified entity.
|
||||
"""
|
||||
try:
|
||||
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
"""
|
||||
Gets the index of an entity in the collection.
|
||||
"""
|
||||
try:
|
||||
return h.get_first_index(self, filter_by=lambda x: x == entity)
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the collection by clearing data and observers.
|
||||
"""
|
||||
self._data = defaultdict(lambda: None)
|
||||
self._observers = set(self)
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
|
Reference in New Issue
Block a user