Chanumask ee4d29d50b Merge branch 'main' into refactor_rename
# Conflicts:
#	marl_factory_grid/configs/default_config.yaml
#	marl_factory_grid/environment/entity/object.py
2023-10-31 10:28:25 +01:00

95 lines
3.0 KiB
Python

from collections import defaultdict
from operator import itemgetter
from random import shuffle, random
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects):
    """Collection of named entity groups with a per-position entity index.

    Groups are stored by name in ``self._data`` (see :meth:`add_item`), and
    ``self.pos_dict`` maps a position to the list of entities found there.
    Nothing in this class writes entries into ``pos_dict``; it is presumably
    maintained through the observer mechanism of the base class -- TODO confirm.
    """

    _entity = _Objects

    def __init__(self, floor_positions):
        """Store the floor positions and create an empty position index.

        :param floor_positions: Mutable sequence of floor positions; it is
            shuffled in place on every access of :attr:`floorlist`.
        """
        # Both attributes are assigned before the base initializer runs; this
        # order is kept in case the base class relies on them.
        self._floor_positions = floor_positions
        self.pos_dict = defaultdict(list)
        super().__init__()

    @staticmethod
    def neighboring_positions(pos):
        """Return ``POS_MASK`` shifted by ``pos`` as rows of two coordinates.

        Assumes ``POS_MASK`` is an array-like that supports element-wise
        addition with ``pos`` and ``reshape`` -- TODO confirm.
        """
        return (POS_MASK + pos).reshape(-1, 2)

    def get_entities_near_pos(self, pos):
        """Return a flat list of all entities indexed at the neighbors of ``pos``.

        :param pos: Position whose neighborhood (as defined by
            :meth:`neighboring_positions`) is inspected.
        :return: List of entities, in neighbor order.
        """
        nearby = []
        for neighbor in self.neighboring_positions(pos):
            # A plain loop instead of ``itemgetter(*keys)``: itemgetter returns
            # a bare item (not a tuple) for exactly one key and raises for zero
            # keys, which would break the flattening in those cases.
            nearby.extend(self.pos_dict[tuple(neighbor)])
        return nearby

    def render(self):
        """Return the concatenated ``render()`` output of all non-``None`` members."""
        # The ``None`` filter has to run before ``render()`` is called on the
        # member, otherwise it can never protect against a ``None`` member.
        return [rendered for member in self if member is not None for rendered in member.render()]

    @property
    def names(self):
        """List of the keys under which groups are stored."""
        return list(self._data.keys())

    @property
    def floorlist(self):
        """The floor positions, shuffled in place on every access."""
        shuffle(self._floor_positions)
        return self._floor_positions

    def guests_that_can_collide(self, pos):
        """Return the items at ``pos`` whose ``var_can_collide`` is truthy.

        NOTE(review): every object stored under ``pos`` is itself iterated
        here, whereas :meth:`is_blocked` reads attributes directly off the
        stored objects -- confirm which nesting level is intended.
        """
        return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]

    # NOTE(review): this is a plain method while ``occupied_positions`` is a
    # property -- kept as is so existing callers keep working.
    def empty_positions(self):
        """Return the floor positions without indexed entities, in random order."""
        free = [position for position in self.floorlist if not self.pos_dict[position]]
        shuffle(free)
        return free

    @property
    def occupied_positions(self):
        """The floor positions that hold at least one entity, in random order."""
        occupied = [position for position in self.floorlist if self.pos_dict[position]]
        shuffle(occupied)
        return occupied

    def is_blocked(self):
        """Return the positions where any entity has ``var_is_blocking_pos`` set."""
        return [
            position
            for position, entities in self.pos_dict.items()
            if any(entity.var_is_blocking_pos for entity in entities)
        ]

    def is_not_blocked(self):
        """Return the positions where not every entity has ``var_is_blocking_pos`` set.

        NOTE(review): because this uses ``not all`` while :meth:`is_blocked`
        uses ``any``, a position with both blocking and non-blocking entities
        is reported by both methods, and a position with an empty entity list
        is reported by neither -- confirm this is intended.
        """
        return [
            position
            for position, entities in self.pos_dict.items()
            if not all(entity.var_is_blocking_pos for entity in entities)
        ]

    def iter_entities(self):
        """Return an iterator over every element of every value in this collection."""
        return (x for sublist in self.values() for x in sublist)

    def add_items(self, items: Dict):
        """Add several named groups at once; see :meth:`add_item`."""
        return self.add_item(items)

    def add_item(self, item: dict):
        """Register new named groups and subscribe to each of them as observer.

        :param item: Mapping of group name to group object.
        :return: ``self``, to allow chaining.
        :raises AssertionError: If any of the names is already present.
        """
        assert_str = 'This group of entity has already been added!'
        # Test membership directly; checking the truthiness of the colliding
        # keys would let a falsy duplicate key (e.g. '') slip through.
        assert not any(key in self.keys() for key in item), assert_str
        self._data.update(item)
        for val in item.values():
            val.add_observer(self)
        return self

    def __contains__(self, item):
        """Return whether ``item`` is a key of the stored groups."""
        return item in self._data

    def __delitem__(self, name):
        """Remove the group stored under ``name`` and unsubscribe from it.

        :param name: Key of the group to remove.
        :raises AssertionError: If no group is stored under ``name``.
        """
        assert_str = 'This group of entity does not exist in this collection!'
        # ``name`` is the key itself (it is used as ``self[name]`` below), so
        # membership is checked on it directly instead of calling ``.keys()``.
        assert name in self.keys(), assert_str
        # NOTE(review): neither ``set`` nor ``list`` offers ``delete`` --
        # confirm the type of ``_observers`` on the base class.
        self[name]._observers.delete(self)
        for entity in self[name]:
            entity.del_observer(self)
        return super().__delitem__(name)

    @property
    def obs_pairs(self):
        """The concatenated ``obs_pairs`` of all members.

        Returns ``None`` (after printing a debug hint) if a member lacks the
        ``obs_pairs`` attribute.
        """
        try:
            return [pair for member in self for pair in member.obs_pairs]
        except AttributeError:
            print('OhOh (debug me)')
            return None

    def by_pos(self, pos: (int, int)):
        """Return the list of entities indexed at ``pos``."""
        return self.pos_dict[pos]

    @property
    def positions(self):
        """Every indexed position, repeated once per entity located there."""
        return [k for k, v in self.pos_dict.items() for _ in v]