mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-25 04:41:34 +02:00
added documentation for env groups
This commit is contained in:
@ -11,12 +11,30 @@ class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
def neighboring_positions(self, pos):
|
||||
"""
|
||||
Get all 8 neighboring positions of a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
|
||||
|
||||
def neighboring_4_positions(self, pos):
|
||||
"""
|
||||
Get neighboring 4 positions of a given position. (North, East, South, West)
|
||||
|
||||
:param pos: Reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions]
|
||||
|
||||
def get_entities_near_pos(self, pos):
|
||||
"""
|
||||
Get entities near a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities near the position.
|
||||
"""
|
||||
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
@ -28,10 +46,18 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def floorlist(self):
|
||||
"""
|
||||
Shuffle and return the list of floor positions.
|
||||
|
||||
:return: Shuffled list of floor positions.
|
||||
"""
|
||||
shuffle(self._floor_positions)
|
||||
return [x for x in self._floor_positions]
|
||||
|
||||
def __init__(self, floor_positions):
|
||||
"""
|
||||
:param floor_positions: list of all positions that are not blocked by a wall.
|
||||
"""
|
||||
self._floor_positions = floor_positions
|
||||
self.pos_dict = None
|
||||
super().__init__()
|
||||
@ -40,28 +66,54 @@ class Entities(Objects):
|
||||
return f'{self.__class__.__name__}{[x for x in self]}'
|
||||
|
||||
def guests_that_can_collide(self, pos):
|
||||
"""
|
||||
Get entities at a position that can collide.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position that can collide.
|
||||
"""
|
||||
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def empty_positions(self):
|
||||
"""
|
||||
Get shuffled list of empty positions.
|
||||
|
||||
:return: Shuffled list of empty positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def occupied_positions(self): # positions that are not empty
|
||||
def occupied_positions(self):
|
||||
"""
|
||||
Get shuffled list of occupied positions.
|
||||
|
||||
:return: Shuffled list of occupied positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def blocked_positions(self):
|
||||
"""
|
||||
Get shuffled list of blocked positions.
|
||||
|
||||
:return: Shuffled list of blocked positions.
|
||||
"""
|
||||
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||
shuffle(blocked_positions)
|
||||
return blocked_positions
|
||||
|
||||
@property
|
||||
def free_positions_generator(self):
|
||||
"""
|
||||
Get a generator for free positions.
|
||||
|
||||
:return: Generator for free positions.
|
||||
"""
|
||||
generator = (
|
||||
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
|
||||
for x in self.pos_dict[key])
|
||||
@ -70,9 +122,19 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def free_positions_list(self):
|
||||
"""
|
||||
Get a list of free positions.
|
||||
|
||||
:return: List of free positions.
|
||||
"""
|
||||
return [x for x in self.free_positions_generator]
|
||||
|
||||
def iter_entities(self):
|
||||
"""
|
||||
Get an iterator over all entities in the collection.
|
||||
|
||||
:return: Iterator over entities.
|
||||
"""
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def add_items(self, items: Dict):
|
||||
@ -105,13 +167,30 @@ class Entities(Objects):
|
||||
print('OhOh (debug me)')
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
"""
|
||||
Get entities at a specific position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position.
|
||||
"""
|
||||
return self.pos_dict[pos]
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
"""
|
||||
Get a list of all positions in the collection.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
return [k for k, v in self.pos_dict.items() for _ in v]
|
||||
|
||||
def is_occupied(self, pos):
|
||||
"""
|
||||
Check if a position is occupied.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: True if the position is occupied, False otherwise.
|
||||
"""
|
||||
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||
|
||||
def reset(self):
|
||||
|
Reference in New Issue
Block a user