from collections import defaultdict from typing import List, Iterator, Union import numpy as np from marl_factory_grid.environment.entity.object import Object import marl_factory_grid.environment.constants as c from marl_factory_grid.utils import helpers as h class Objects: _entity = Object @property def var_can_be_bound(self): return False @property def observers(self): return self._observers @property def obs_tag(self): return self.__class__.__name__ @staticmethod def render(): return [] @property def obs_pairs(self): pair_list = [(self.name, self)] pair_list.extend([(a.name, a) for a in self]) return pair_list @property def names(self): # noinspection PyUnresolvedReferences return [x.name for x in self] @property def name(self): return f'{self.__class__.__name__}' def __init__(self, *args, **kwargs): self._data = defaultdict(lambda: None) self._observers = [self] self.pos_dict = defaultdict(list) def __len__(self): return len(self._data) def __iter__(self) -> Iterator[Union[Object, None]]: return iter(self.values()) def add_item(self, item: _entity): assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,' assert isinstance(item, self._entity), assert_str assert self._data[item.name] is None, f'{item.name} allready exists!!!' self._data.update({item.name: item}) item.set_collection(self) for observer in self.observers: observer.notify_add_entity(item) return self def remove_item(self, item: _entity): for observer in item.observers: observer.notify_del_entity(item) # noinspection PyTypeChecker del self._data[item.name] return True def __delitem__(self, name): return self.remove_item(self[name]) # noinspection PyUnresolvedReferences def del_observer(self, observer): self.observers.remove(observer) for entity in self: if observer in entity.observers: entity.del_observer(observer) # noinspection PyUnresolvedReferences def add_observer(self, observer): self.observers.append(observer) for entity in self: if observer not in entity.observers: entity.add_observer(observer) def add_items(self, items: List[_entity]): for item in items: self.add_item(item) return self def keys(self): return self._data.keys() def values(self): return self._data.values() def items(self): return self._data.items() def _get_index(self, item): try: return next(i for i, v in enumerate(self._data.values()) if v == item) except StopIteration: return None def by_name(self, name): return next(x for x in self if x.name == name) def __getitem__(self, item): if isinstance(item, (int, np.int64, np.int32)): if item < 0: item = len(self._data) - abs(item) try: return next(v for i, v in enumerate(self._data.values()) if i == item) except StopIteration: return None try: return self._data[item] except KeyError: return None except TypeError: print('Ups') raise TypeError def __repr__(self): repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' def notify_del_entity(self, entity: Object): try: # noinspection PyUnresolvedReferences self.pos_dict[entity.pos].remove(entity) except (AttributeError, ValueError, IndexError): pass def notify_add_entity(self, entity: Object): try: if self not in entity.observers: entity.add_observer(self) if entity.var_has_position: if entity not in self.pos_dict[entity.pos]: self.pos_dict[entity.pos].append(entity) except (ValueError, AttributeError): pass def summarize_states(self): # FIXME PROTOBUFF # return [e.summarize_state() for e in self] return [e.summarize_state() for e in self] def by_entity(self, entity): try: return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None def idx_by_entity(self, entity): try: return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity)) except (StopIteration, AttributeError): return None