mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-26 05:01:36 +02:00
Maintainer and pos_dicts fixed. Are sets now.
This commit is contained in:
@ -2,7 +2,6 @@ from typing import List, Tuple, Union, Dict
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
# noinspection PyProtectedMember
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
@ -31,9 +31,12 @@ class Entities(Objects):
|
||||
|
||||
def __init__(self, floor_positions):
|
||||
self._floor_positions = floor_positions
|
||||
self.pos_dict = defaultdict(list)
|
||||
self.pos_dict = None
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}{[x for x in self]}'
|
||||
|
||||
def guests_that_can_collide(self, pos):
|
||||
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@ -108,3 +111,12 @@ class Entities(Objects):
|
||||
|
||||
def is_occupied(self, pos):
|
||||
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||
|
||||
def reset(self):
|
||||
self._observers = set(self)
|
||||
self.pos_dict = defaultdict(list)
|
||||
for entity_group in self:
|
||||
entity_group.reset()
|
||||
|
||||
if hasattr(entity_group, "var_has_position") and entity_group.var_has_position:
|
||||
entity_group.add_observer(self)
|
||||
|
@ -44,7 +44,7 @@ class Objects:
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._data = defaultdict(lambda: None)
|
||||
self._observers = [self]
|
||||
self._observers = set(self)
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
def __len__(self):
|
||||
@ -59,6 +59,8 @@ class Objects:
|
||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||
self._data.update({item.name: item})
|
||||
item.set_collection(self)
|
||||
if hasattr(self, "var_has_position") and self.var_has_position:
|
||||
item.add_observer(self)
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(item)
|
||||
return self
|
||||
@ -82,10 +84,9 @@ class Objects:
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def add_observer(self, observer):
|
||||
self.observers.append(observer)
|
||||
self.observers.add(observer)
|
||||
for entity in self:
|
||||
if observer not in entity.observers:
|
||||
entity.add_observer(observer)
|
||||
entity.add_observer(observer)
|
||||
|
||||
def add_items(self, items: List[_entity]):
|
||||
for item in items:
|
||||
@ -127,8 +128,7 @@ class Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
return f'{self.__class__.__name__}[{len(self)}]'
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
@ -163,3 +163,9 @@ class Objects:
|
||||
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
self._data = defaultdict(lambda: None)
|
||||
self._observers = set(self)
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
|
@ -23,3 +23,7 @@ class Walls(Collection):
|
||||
return super().by_pos(pos)[0]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
|
Reference in New Issue
Block a user