diff --git a/README.md b/README.md index 34208d6..a1d2740 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ General: level_name: rooms # 'double', 'large', 'simple', ... ``` ... or create your own , maybe with the help of [asciiflow.com](https://asciiflow.com/#/). -Make sure to use `#` as [Walls](marl_factory_grid/environment/entity/wall_floor.py), `-` as free (walkable) [Floor](marl_factory_grid/environment/entity/wall_floor.py)-Tiles, `D` for [Walls](./modules/doors/entities.py). +Make sure to use `#` as [Walls](marl_factory_grid/environment/entity/wall.py), `-` as free (walkable) [Floor](marl_factory_grid/environment/entity/wall.py)-Tiles, `D` for [Walls](./modules/doors/entities.py). Other Entites (define you own) may bring their own `Symbols` #### Entites diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 8cac2aa..637827f 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -1,14 +1,15 @@ import abc +from collections import defaultdict import numpy as np +from .object import _Object from .. import constants as c -from .object import EnvObject -from ...utils.utility_classes import RenderEntity from ...utils.results import ActionResult +from ...utils.utility_classes import RenderEntity -class Entity(EnvObject, abc.ABC): +class Entity(_Object, abc.ABC): """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" @property @@ -19,6 +20,34 @@ class Entity(EnvObject, abc.ABC): def var_has_position(self): return self.pos != c.VALUE_NO_POS + @property + def var_is_blocking_light(self): + try: + return self._collection.var_is_blocking_light or False + except AttributeError: + return False + + @property + def var_can_move(self): + try: + return self._collection.var_can_move or False + except AttributeError: + return False + + @property + def var_is_blocking_pos(self): + try: + return self._collection.var_is_blocking_pos or False + except AttributeError: + return False + + @property + def var_can_collide(self): + try: + return self._collection.var_can_collide or False + except AttributeError: + return False + @property def x(self): return self.pos[0] @@ -54,7 +83,7 @@ class Entity(EnvObject, abc.ABC): if valid := state.check_move_validity(self, next_pos): for observer in self.observers: observer.notify_del_entity(self) - self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1] + self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] self._pos = next_pos for observer in self.observers: observer.notify_add_entity(self) @@ -73,7 +102,7 @@ class Entity(EnvObject, abc.ABC): print(f'Objects of class "{self.__class__.__name__}" can not be bound to other entities.') exit() - def summarize_state(self) -> dict: # tile=str(self.tile.name) + def summarize_state(self) -> dict: return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide)) @abc.abstractmethod @@ -82,3 +111,42 @@ class Entity(EnvObject, abc.ABC): def __repr__(self): return super(Entity, self).__repr__() + f'(@{self.pos})' + + @property + def obs_tag(self): + try: + return self._collection.name or self.name + except AttributeError: + return self.name + + @property + def encoding(self): + return c.VALUE_OCCUPIED_CELL + + def change_parent_collection(self, other_collection): + other_collection.add_item(self) + self._collection.delete_env_object(self) + self._collection = other_collection + return self._collection == other_collection + + @classmethod + def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): + collection = cls(*args, **kwargs) + collection.add_items( + [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) + return collection + + def notify_del_entity(self, entity): + try: + self.pos_dict[entity.pos].remove(entity) + except (ValueError, AttributeError): + pass + + def by_pos(self, pos: (int, int)): + pos = tuple(pos) + try: + return self.state.entities.pos_dict[pos] + except StopIteration: + pass + except ValueError: + print() diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 8e6e02c..8810baf 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -2,10 +2,10 @@ from collections import defaultdict from typing import Union from marl_factory_grid.environment import constants as c +import marl_factory_grid.utils.helpers as h -class Object: - +class _Object: """Generell Objects for Organisation and Maintanance such as Actions etc...""" _u_idx = defaultdict(lambda: 0) @@ -13,6 +13,17 @@ class Object: def __bool__(self): return True + @property + def var_has_position(self): + return False + + @property + def var_can_be_bound(self): + try: + return self._collection.var_can_be_bound or False + except AttributeError: + return False + @property def observers(self): return self._observers @@ -20,8 +31,14 @@ class Object: @property def name(self): if self._str_ident is not None: - return f'{self.__class__.__name__}[{self._str_ident}]' - return f'{self.__class__.__name__}#{self.u_int}' + name = f'{self.__class__.__name__}[{self._str_ident}]' + else: + name = f'{self.__class__.__name__}#{self.u_int}' + if self.bound_entity: + name = h.add_bound_name(name, self.bound_entity) + if self.var_has_position: + name = h.add_pos_name(name, self) + return name @property def identifier(self): @@ -35,6 +52,7 @@ class Object: return True def __init__(self, str_ident: Union[str, None] = None, **kwargs): + self._bound_entity = None self._observers = [] self._str_ident = str_ident self.u_int = self._identify_and_count_up() @@ -53,8 +71,8 @@ class Object: return hash(self.identifier) def _identify_and_count_up(self): - idx = Object._u_idx[self.__class__.__name__] - Object._u_idx[self.__class__.__name__] += 1 + idx = _Object._u_idx[self.__class__.__name__] + _Object._u_idx[self.__class__.__name__] += 1 return idx def set_collection(self, collection): @@ -70,75 +88,96 @@ class Object: def summarize_state(self): return dict() + def bind(self, entity): + # noinspection PyAttributeOutsideInit + self._bound_entity = entity + return c.VALID -class EnvObject(Object): - - """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" - - _u_idx = defaultdict(lambda: 0) - - @property - def obs_tag(self): - try: - return self._collection.name or self.name - except AttributeError: - return self.name + def belongs_to_entity(self, entity): + return self._bound_entity == entity @property - def var_is_blocking_light(self): - try: - return self._collection.var_is_blocking_light or False - except AttributeError: - return False + def bound_entity(self): + return self._bound_entity - @property - def var_can_be_bound(self): - try: - return self._collection.var_can_be_bound or False - except AttributeError: - return False + def bind_to(self, entity): + self._bound_entity = entity - @property - def var_can_move(self): - try: - return self._collection.var_can_move or False - except AttributeError: - return False - - @property - def var_is_blocking_pos(self): - try: - return self._collection.var_is_blocking_pos or False - except AttributeError: - return False - - @property - def var_has_position(self): - try: - return self._collection.var_has_position or False - except AttributeError: - return False - - @property - def var_can_collide(self): - try: - return self._collection.var_can_collide or False - except AttributeError: - return False - - @property - def encoding(self): - return c.VALUE_OCCUPIED_CELL - - def __init__(self, **kwargs): + def unbind(self): self._bound_entity = None - super(EnvObject, self).__init__(**kwargs) - def change_parent_collection(self, other_collection): - other_collection.add_item(self) - self._collection.delete_env_object(self) - self._collection = other_collection - return self._collection == other_collection - def summarize_state(self): - return dict(name=str(self.name)) +# class EnvObject(_Object): +# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" +# + # _u_idx = defaultdict(lambda: 0) +# +# @property +# def obs_tag(self): +# try: +# return self._collection.name or self.name +# except AttributeError: +# return self.name +# +# @property +# def var_is_blocking_light(self): +# try: +# return self._collection.var_is_blocking_light or False +# except AttributeError: +# return False +# +# @property +# def var_can_be_bound(self): +# try: +# return self._collection.var_can_be_bound or False +# except AttributeError: +# return False +# +# @property +# def var_can_move(self): +# try: +# return self._collection.var_can_move or False +# except AttributeError: +# return False +# +# @property +# def var_is_blocking_pos(self): +# try: +# return self._collection.var_is_blocking_pos or False +# except AttributeError: +# return False +# +# @property +# def var_has_position(self): +# try: +# return self._collection.var_has_position or False +# except AttributeError: +# return False +# +# @property +# def var_can_collide(self): +# try: +# return self._collection.var_can_collide or False +# except AttributeError: +# return False +# +# +# @property +# def encoding(self): +# return c.VALUE_OCCUPIED_CELL +# +# +# def __init__(self, **kwargs): +# self._bound_entity = None +# super(EnvObject, self).__init__(**kwargs) +# +# +# def change_parent_collection(self, other_collection): +# other_collection.add_item(self) +# self._collection.delete_env_object(self) +# self._collection = other_collection +# return self._collection == other_collection +# +# +# def summarize_state(self): +# return dict(name=str(self.name)) diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index fbf0c4a..1a5cbe3 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -1,9 +1,6 @@ -import math - import numpy as np -from marl_factory_grid.environment.entity.mixin import BoundEntityMixin -from marl_factory_grid.environment.entity.object import Object, EnvObject +from marl_factory_grid.environment.entity.object import _Object ########################################################################## @@ -11,7 +8,7 @@ from marl_factory_grid.environment.entity.object import Object, EnvObject ########################################################################## -class PlaceHolder(Object): +class PlaceHolder(_Object): def __init__(self, *args, fill_value=0, **kwargs): super().__init__(*args, **kwargs) @@ -30,7 +27,7 @@ class PlaceHolder(Object): return "PlaceHolder" -class GlobalPosition(BoundEntityMixin, EnvObject): +class GlobalPosition(_Object): @property def encoding(self): diff --git a/marl_factory_grid/environment/entity/wall.py b/marl_factory_grid/environment/entity/wall.py new file mode 100644 index 0000000..3f0fb7c --- /dev/null +++ b/marl_factory_grid/environment/entity/wall.py @@ -0,0 +1,29 @@ +from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.utils.utility_classes import RenderEntity + + +class Wall(Entity): + + @property + def var_has_position(self): + return True + + @property + def var_can_collide(self): + return True + + @property + def encoding(self): + return c.VALUE_OCCUPIED_CELL + + def render(self): + return RenderEntity(c.WALL, self.pos) + + @property + def var_is_blocking_pos(self): + return True + + @property + def var_is_blocking_light(self): + return True diff --git a/marl_factory_grid/environment/entity/wall_floor.py b/marl_factory_grid/environment/entity/wall_floor.py deleted file mode 100644 index e8b153e..0000000 --- a/marl_factory_grid/environment/entity/wall_floor.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import List - -import numpy as np - -from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.entity.object import EnvObject -from marl_factory_grid.utils.utility_classes import RenderEntity -from marl_factory_grid.utils import helpers as h - - -class Floor(EnvObject): - - @property - def var_has_position(self): - return True - - @property - def var_can_collide(self): - return False - - @property - def var_can_move(self): - return False - - @property - def var_is_blocking_pos(self): - return False - - @property - def var_is_blocking_light(self): - return False - - @property - def encoding(self): - return c.VALUE_OCCUPIED_CELL - - # @property - # def guests_that_can_collide(self): - # return [x for x in self.guests if x.var_can_collide] - - @property - def guests(self): - return self._guests.values() - - @property - def x(self): - return self.pos[0] - - @property - def y(self): - return self.pos[1] - - @property - def is_blocked(self): - return any([x.var_is_blocking_pos for x in self.guests]) - - def __init__(self, pos, **kwargs): - super(Floor, self).__init__(**kwargs) - self._guests = dict() - self.pos = tuple(pos) - self._neighboring_floor: List[Floor] = list() - self._blocked_by = None - - def __len__(self): - return len(self._guests) - - def is_empty(self): - return not len(self._guests) - - def is_occupied(self): - return bool(len(self._guests)) - - def enter(self, guest, spawn=False): - same_pos = guest.name not in self._guests - not_blocked = not self.is_blocked - no_become_blocked_when_occupied = not (guest.var_is_blocking_pos and self.is_occupied()) - not_introduce_collision = not (spawn and guest.var_can_collide and any(x.var_can_collide for x in self.guests)) - if same_pos and not_blocked and no_become_blocked_when_occupied and not_introduce_collision: - self._guests.update({guest.name: guest}) - return c.VALID - else: - return c.NOT_VALID - - def leave(self, guest): - try: - del self._guests[guest.name] - except (ValueError, KeyError): - return c.NOT_VALID - return c.VALID - - def __repr__(self): - return f'{self.name}(@{self.pos})' - - def summarize_state(self, **_): - return dict(name=self.name, x=int(self.x), y=int(self.y)) - - def render(self): - return None - - -class Wall(Floor): - - @property - def var_can_collide(self): - return True - - @property - def encoding(self): - return c.VALUE_OCCUPIED_CELL - - def render(self): - return RenderEntity(c.WALL, self.pos) - - @property - def var_is_blocking_pos(self): - return True - - @property - def var_is_blocking_light(self): - return True diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index 0169f88..f4a6ac6 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -1,12 +1,21 @@ from marl_factory_grid.environment.entity.agent import Agent -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin +from marl_factory_grid.environment.groups.collection import Collection -class Agents(PositionMixin, EnvObjects): +class Agents(Collection): _entity = Agent - is_blocking_light = False - can_move = True + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_move(self): + return True + + @property + def var_has_position(self): + return True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py new file mode 100644 index 0000000..640c3b4 --- /dev/null +++ b/marl_factory_grid/environment/groups/collection.py @@ -0,0 +1,128 @@ +from typing import List, Tuple, Union + +from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.entity.object import _Object +import marl_factory_grid.environment.constants as c + + +class Collection(_Objects): + _entity = _Object # entity? + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_has_position(self): + return False + + # @property + # def var_has_bound(self): + # return False # batteries, globalpos, inventories true + + @property + def var_can_be_bound(self): + return False + + @property + def encodings(self): + return [x.encoding for x in self] + + def __init__(self, size, *args, **kwargs): + super(Collection, self).__init__(*args, **kwargs) + self.size = size + + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args + if isinstance(coords_or_quantity, int): + self.add_items([self._entity() for _ in range(coords_or_quantity)]) + else: + self.add_items([self._entity(pos) for pos in coords_or_quantity]) + return c.VALID + + def despawn(self, items: List[_Object]): + items = [items] if isinstance(items, _Object) else items + for item in items: + del self[item] + + def add_item(self, item: Entity): + assert self.var_has_position or (len(self) <= self.size) + super(Collection, self).add_item(item) + return self + + def delete_env_object(self, env_object): + del self[env_object.name] + + def delete_env_object_by_name(self, name): + del self[name] + + @property + def obs_pairs(self): + pair_list = [(self.name, self)] + try: + if self.var_can_be_bound: + pair_list.extend([(a.name, a) for a in self]) + except AttributeError: + pass + return pair_list + + def by_entity(self, entity): + try: + return next((x for x in self if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None + + def idx_by_entity(self, entity): + try: + return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None + + def render(self): + if self.var_has_position: + return [y for y in [x.render() for x in self] if y is not None] + else: + return [] + + @classmethod + def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): + collection = cls(*args, **kwargs) + collection.add_items( + [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) + return collection + + def __delitem__(self, name): + idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) + try: + for observer in obj.observers: + observer.notify_del_entity(obj) + except AttributeError: + pass + super().__delitem__(name) + + def by_pos(self, pos: (int, int)): + pos = tuple(pos) + try: + return self.pos_dict[pos] + except StopIteration: + pass + except ValueError: + print() + + @property + def positions(self): + return [e.pos for e in self] + + def notify_del_entity(self, entity: Entity): + try: + self.pos_dict[entity.pos].remove(entity) + except (ValueError, AttributeError): + pass diff --git a/marl_factory_grid/environment/groups/env_objects.py b/marl_factory_grid/environment/groups/env_objects.py deleted file mode 100644 index 1113833..0000000 --- a/marl_factory_grid/environment/groups/env_objects.py +++ /dev/null @@ -1,31 +0,0 @@ -from marl_factory_grid.environment.groups.objects import Objects -from marl_factory_grid.environment.entity.object import EnvObject - - -class EnvObjects(Objects): - - _entity = EnvObject - var_is_blocking_light: bool = False - var_can_collide: bool = False - var_has_position: bool = False - var_can_move: bool = False - var_can_be_bound: bool = False - - @property - def encodings(self): - return [x.encoding for x in self] - - def __init__(self, size, *args, **kwargs): - super(EnvObjects, self).__init__(*args, **kwargs) - self.size = size - - def add_item(self, item: EnvObject): - assert self.var_has_position or (len(self) <= self.size) - super(EnvObjects, self).add_item(item) - return self - - def delete_env_object(self, env_object: EnvObject): - del self[env_object.name] - - def delete_env_object_by_name(self, name): - del self[name] diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 60915f9..8bfc9fe 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -3,12 +3,12 @@ from operator import itemgetter from random import shuffle, random from typing import Dict -from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.utils.helpers import POS_MASK -class Entities(Objects): - _entity = Objects +class Entities(_Objects): + _entity = _Objects @staticmethod def neighboring_positions(pos): @@ -34,13 +34,9 @@ class Entities(Objects): self.pos_dict = defaultdict(list) super().__init__() - # def all_floors(self): - # return[key for key, val in self.pos_dict.items() if any('floor' in x.name.lower() for x in val)] - def guests_that_can_collide(self, pos): return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] - @property def empty_positions(self): empty_positions = [key for key in self.floorlist if not self.pos_dict[key]] shuffle(empty_positions) @@ -92,8 +88,6 @@ class Entities(Objects): def by_pos(self, pos: (int, int)): return self.pos_dict[pos] - # found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None] - # return found_entities @property def positions(self): diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 7171d43..48333ca 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -1,68 +1,4 @@ -from typing import List, Tuple - -import numpy as np - from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.entity.wall_floor import Floor - - -class PositionMixin: - _entity = Entity - var_is_blocking_light: bool = True - var_can_collide: bool = True - var_has_position: bool = True - - def spawn(self, coords: List[Tuple[(int, int)]]): - self.add_items([self._entity(pos) for pos in coords]) - - def render(self): - return [y for y in [x.render() for x in self] if y is not None] - - # @classmethod - # def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs): - # collection = cls(*args, **kwargs) - # entities = [cls._entity(tile, str_ident=i, - # **entity_kwargs if entity_kwargs is not None else {}) - # for i, tile in enumerate(tiles)] - # collection.add_items(entities) - # return collection - - @classmethod - def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ): - collection = cls(*args, **kwargs) - collection.add_items( - [cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions]) - return collection - - def __delitem__(self, name): - idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) - try: - for observer in obj.observers: - observer.notify_del_entity(obj) - except AttributeError: - pass - super().__delitem__(name) - - def by_pos(self, pos: (int, int)): - pos = tuple(pos) - try: - return self.pos_dict[pos] - # return next(e for e in self if e.pos == pos) - except StopIteration: - pass - except ValueError: - print() - - @property - def positions(self): - return [e.pos for e in self] - - def notify_del_entity(self, entity: Entity): - try: - self.pos_dict[entity.pos].remove(entity) - except (ValueError, AttributeError): - pass # noinspection PyUnresolvedReferences,PyTypeChecker diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index 76346cd..d3f32af 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -3,12 +3,12 @@ from typing import List import numpy as np -from marl_factory_grid.environment.entity.object import Object +from marl_factory_grid.environment.entity.object import _Object import marl_factory_grid.environment.constants as c -class Objects: - _entity = Object +class _Objects: + _entity = _Object @property def observers(self): @@ -54,7 +54,6 @@ class Objects: assert self._data[item.name] is None, f'{item.name} allready exists!!!' self._data.update({item.name: item}) item.set_collection(self) - # self.notify_add_entity(item) for observer in self.observers: observer.notify_add_entity(item) return self @@ -123,36 +122,16 @@ class Objects: raise TypeError def __repr__(self): - repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]} + repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' - def spawn(self, n: int): - self.add_items([self._entity() for _ in range(n)]) - return c.VALID - - def despawn(self, items: List[Object]): - items = [items] if isinstance(items, Object) else items - for item in items: - del self[item] - - # def notify_change_pos(self, entity: object): - # try: - # self.pos_dict[entity.last_pos].remove(entity) - # except (ValueError, AttributeError): - # pass - # if entity.var_has_position: - # try: - # self.pos_dict[entity.pos].append(entity) - # except (ValueError, AttributeError): - # pass - - def notify_del_entity(self, entity: Object): + def notify_del_entity(self, entity: _Object): try: self.pos_dict[entity.pos].remove(entity) except (AttributeError, ValueError, IndexError): pass - def notify_add_entity(self, entity: Object): + def notify_add_entity(self, entity: _Object): try: if self not in entity.observers: entity.add_observer(self) @@ -166,3 +145,15 @@ class Objects: # FIXME PROTOBUFF # return [e.summarize_state() for e in self] return [e.summarize_state() for e in self] + + def by_entity(self, entity): + try: + return next((x for x in self if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None + + def idx_by_entity(self, entity): + try: + return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None diff --git a/marl_factory_grid/environment/groups/utils.py b/marl_factory_grid/environment/groups/utils.py index a3f90b0..5619041 100644 --- a/marl_factory_grid/environment/groups/utils.py +++ b/marl_factory_grid/environment/groups/utils.py @@ -1,17 +1,14 @@ from typing import List, Union -import numpy as np - from marl_factory_grid.environment.entity.util import GlobalPosition -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin -from marl_factory_grid.environment.groups.objects import Objects -from marl_factory_grid.modules.zones import Zone -from marl_factory_grid.utils import helpers as h -from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.groups.collection import Collection -class Combined(PositionMixin, EnvObjects): +class Combined(Collection): + + @property + def var_has_position(self): + return True @property def name(self): @@ -35,11 +32,21 @@ class Combined(PositionMixin, EnvObjects): return [(name, None) for name in self.names] -class GlobalPositions(HasBoundMixin, EnvObjects): +class GlobalPositions(Collection): _entity = GlobalPosition - is_blocking_light = False, - can_collide = False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_be_bound(self): + return True def __init__(self, *args, **kwargs): super(GlobalPositions, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/environment/groups/walls.py b/marl_factory_grid/environment/groups/walls.py index 7fcd939..2d85362 100644 --- a/marl_factory_grid/environment/groups/walls.py +++ b/marl_factory_grid/environment/groups/walls.py @@ -1,43 +1,22 @@ -import random -from typing import List, Tuple - from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin -from marl_factory_grid.environment.entity.wall_floor import Wall, Floor +from marl_factory_grid.environment.entity.wall import Wall +from marl_factory_grid.environment.groups.collection import Collection -class Walls(PositionMixin, EnvObjects): +class Walls(Collection): _entity = Wall symbol = c.SYMBOL_WALL + @property + def var_has_position(self): + return True + def __init__(self, *args, **kwargs): super(Walls, self).__init__(*args, **kwargs) self._value = c.VALUE_OCCUPIED_CELL - #ToDo: Do we need this? Move to spawn methode? - # @classmethod - # def from_coordinates(cls, argwhere_coordinates, *args, **kwargs): - # tiles = cls(*args, **kwargs) - # # noinspection PyTypeChecker - # tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates]) - # return tiles - def by_pos(self, pos: (int, int)): try: return super().by_pos(pos)[0] except IndexError: return None - - -class Floors(Walls): - _entity = Floor - symbol = c.SYMBOL_FLOOR - var_is_blocking_light: bool = False - var_can_collide: bool = False - - def __init__(self, *args, **kwargs): - super(Floors, self).__init__(*args, **kwargs) - self._value = c.VALUE_FREE_CELL - - diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 79d4e27..f9678b0 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -49,7 +49,7 @@ class SpawnAgents(Rule): agent_conf = state.agents_conf # agents = Agents(lvl_map.size) agents = state[c.AGENT] - empty_positions = state.entities.empty_positions[:len(agent_conf)] + empty_positions = state.entities.empty_positions()[:len(agent_conf)] for agent_name in agent_conf: actions = agent_conf[agent_name]['actions'].copy() observations = agent_conf[agent_name]['observations'].copy() diff --git a/marl_factory_grid/modules/batteries/constants.py b/marl_factory_grid/modules/batteries/constants.py index cbf3be0..604f886 100644 --- a/marl_factory_grid/modules/batteries/constants.py +++ b/marl_factory_grid/modules/batteries/constants.py @@ -1,5 +1,3 @@ -from typing import NamedTuple, Union - # Battery Env CHARGE_PODS = 'ChargePods' BATTERIES = 'Batteries' diff --git a/marl_factory_grid/modules/batteries/entitites.py b/marl_factory_grid/modules/batteries/entitites.py index 0e8153e..b51f2dd 100644 --- a/marl_factory_grid/modules/batteries/entitites.py +++ b/marl_factory_grid/modules/batteries/entitites.py @@ -1,13 +1,15 @@ -from marl_factory_grid.environment.entity.mixin import BoundEntityMixin -from marl_factory_grid.environment.entity.object import EnvObject -from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.environment.entity.object import _Object +from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.utils.utility_classes import RenderEntity -from marl_factory_grid.modules.batteries import constants as b +class Battery(_Object): -class Battery(BoundEntityMixin, EnvObject): + @property + def var_can_be_bound(self): + return True @property def is_discharged(self): @@ -47,9 +49,6 @@ class Battery(BoundEntityMixin, EnvObject): summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level)) return summary - def render(self): - return None - class Pod(Entity): @@ -66,8 +65,8 @@ class Pod(Entity): def charge_battery(self, battery: Battery): if battery.charge_level == 1.0: return c.NOT_VALID - # if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1: - if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if 'agent' in guest.name.lower()) > 1: + if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if + 'agent' in guest.name.lower()) > 1: return c.NOT_VALID valid = battery.do_charge_action(self.charge_rate) return valid diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py index ac51daf..8d9e060 100644 --- a/marl_factory_grid/modules/batteries/groups.py +++ b/marl_factory_grid/modules/batteries/groups.py @@ -1,13 +1,31 @@ -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin +from typing import Union, List, Tuple + +from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.batteries.entitites import Pod, Battery -class Batteries(HasBoundMixin, EnvObjects): - +class Batteries(Collection): _entity = Battery - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_has_position(self): + return False + + @property + def var_can_be_bound(self): + return True @property def obs_tag(self): @@ -20,9 +38,14 @@ class Batteries(HasBoundMixin, EnvObjects): batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] self.add_items(batteries) + # def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): hat keine pos + # agents = entity_args[0] + # initial_charge_level = entity_args[1] + # batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] + # self.add_items(batteries) -class ChargePods(PositionMixin, EnvObjects): +class ChargePods(Collection): _entity = Pod def __init__(self, *args, **kwargs): diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 3f492e2..e060629 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -49,7 +49,7 @@ class BatteryDecharge(Rule): self.per_action_costs = per_action_costs self.initial_charge = initial_charge - def on_init(self, state, lvl_map): + def on_init(self, state, lvl_map): # on reset? assert len(state[c.AGENT]), "There are no agents, did you already spawn them?" state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index a011b97..63e5898 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -1,15 +1,29 @@ -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin -from marl_factory_grid.modules.clean_up.entitites import DirtPile +from typing import Union, List, Tuple from marl_factory_grid.environment import constants as c from marl_factory_grid.utils.results import Result +from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.modules.clean_up.entitites import DirtPile -class DirtPiles(PositionMixin, EnvObjects): +class DirtPiles(Collection): _entity = DirtPile - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_has_position(self): + return True @property def amount(self): @@ -24,9 +38,10 @@ class DirtPiles(PositionMixin, EnvObjects): self.max_global_amount = max_global_amount self.max_local_amount = max_local_amount - def spawn(self, then_dirty_positions, amount_s) -> Result: + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): + amount_s = entity_args[0] spawn_counter = 0 - for idx, pos in enumerate(then_dirty_positions): + for idx, pos in enumerate(coords_or_quantity): if not self.amount > self.max_global_amount: amount = amount_s[idx] if isinstance(amount_s, list) else amount_s if dirt := self.by_pos(pos): diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py index 8afc116..3f58cdb 100644 --- a/marl_factory_grid/modules/clean_up/rules.py +++ b/marl_factory_grid/modules/clean_up/rules.py @@ -100,7 +100,7 @@ class EntitiesSmearDirtOnMove(Rule): if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): - if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): # pos statt tile + if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): results.append(TickResult(identifier=self.name, entity=entity, reward=0, validity=c.VALID)) return results diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index 42669fd..7b866b7 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -1,21 +1,37 @@ from collections import defaultdict +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.entity.mixin import BoundEntityMixin -from marl_factory_grid.utils.utility_classes import RenderEntity from marl_factory_grid.modules.destinations import constants as d +from marl_factory_grid.utils.utility_classes import RenderEntity -class Destination(BoundEntityMixin, Entity): +class Destination(Entity): - var_can_move = False - var_can_collide = False - var_has_position = True - var_is_blocking_pos = False - var_is_blocking_light = False - var_can_be_bound = True # Introduce this globally! + @property + def var_can_move(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_has_position(self): + return True + + @property + def var_is_blocking_pos(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_be_bound(self): + return True def was_reached(self): return self._was_reached diff --git a/marl_factory_grid/modules/destinations/groups.py b/marl_factory_grid/modules/destinations/groups.py index a220c04..5f91bb4 100644 --- a/marl_factory_grid/modules/destinations/groups.py +++ b/marl_factory_grid/modules/destinations/groups.py @@ -1,14 +1,27 @@ -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin +from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.destinations.entitites import Destination from marl_factory_grid.environment import constants as c from marl_factory_grid.modules.destinations import constants as d -class Destinations(PositionMixin, EnvObjects): +class Destinations(Collection): _entity = Destination - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_has_position(self): + return True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 0eea655..669f74e 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -12,7 +12,7 @@ class DoorIndicator(Entity): return d.VALUE_ACCESS_INDICATOR def render(self): - return None + return [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index 4a947f8..687846e 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -1,16 +1,19 @@ from typing import Union -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin +from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors.entitites import Door -class Doors(PositionMixin, EnvObjects): +class Doors(Collection): symbol = d.SYMBOL_DOOR _entity = Door + @property + def var_has_position(self): + return True + def __init__(self, *args, **kwargs): super(Doors, self).__init__(*args, can_collide=True, **kwargs) diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index 0e5def5..b710282 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -8,7 +8,9 @@ from marl_factory_grid.modules.items import constants as i class Item(Entity): - var_can_collide = False + @property + def var_can_collide(self): + return False def render(self): return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None @@ -71,7 +73,7 @@ class DropOffLocation(Entity): def place_item(self, item: Item): if self.is_full: raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") - return bc.NOT_VALID + return bc.NOT_VALID # in Zeile 81 verschieben? else: self.storage.append(item) item.set_auto_despawn(self.auto_item_despawn_interval) diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 118b512..707f743 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -3,17 +3,27 @@ from random import shuffle from marl_factory_grid.modules.items import constants as i from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.objects import Objects -from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin +from marl_factory_grid.environment.groups.collection import Collection +from marl_factory_grid.environment.groups.objects import _Objects +from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.modules.items.entitites import Item, DropOffLocation -class Items(PositionMixin, EnvObjects): +class Items(Collection): _entity = Item - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_has_position(self): + return False + + @property + def is_blocking_light(self): + return False + + @property + def can_collide(self): + return False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -32,9 +42,13 @@ class Items(PositionMixin, EnvObjects): return 0 -class Inventory(IsBoundMixin, EnvObjects): +class Inventory(IsBoundMixin, Collection): _accepted_objects = Item + @property + def var_can_be_bound(self): + return True + @property def obs_tag(self): return self.name @@ -59,9 +73,12 @@ class Inventory(IsBoundMixin, EnvObjects): self._collection = collection -class Inventories(HasBoundMixin, Objects): +class Inventories(_Objects): _entity = Inventory - var_can_move = False + + @property + def var_can_move(self): + return False def __init__(self, size: int, *args, **kwargs): super(Inventories, self).__init__(*args, **kwargs) @@ -94,17 +111,31 @@ class Inventories(HasBoundMixin, Objects): state[i.INVENTORY].spawn(state[c.AGENT]) -class DropOffLocations(PositionMixin, EnvObjects): +class DropOffLocations(Collection): _entity = DropOffLocation - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_has_position(self): + return True def __init__(self, *args, **kwargs): super(DropOffLocations, self).__init__(*args, **kwargs) @staticmethod def trigger_drop_off_location_spawn(state, n_locations): - empty_positions = state.entities.empty_positions[:n_locations] + empty_positions = state.entities.empty_positions()[:n_locations] do_entites = state[i.DROP_OFF] drop_offs = [DropOffLocation(pos) for pos in empty_positions] do_entites.add_items(drop_offs) diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py index f8a27e7..5f2d970 100644 --- a/marl_factory_grid/modules/machines/groups.py +++ b/marl_factory_grid/modules/machines/groups.py @@ -1,14 +1,26 @@ -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin +from typing import Union, List, Tuple + +from marl_factory_grid.environment.groups.collection import Collection from .entitites import Machine -class Machines(PositionMixin, EnvObjects): +class Machines(Collection): _entity = Machine - is_blocking_light: bool = False - can_collide: bool = False + + @property + def var_can_collide(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True def __init__(self, *args, **kwargs): super(Machines, self).__init__(*args, **kwargs) + diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py index 84cd4ba..84e3410 100644 --- a/marl_factory_grid/modules/machines/rules.py +++ b/marl_factory_grid/modules/machines/rules.py @@ -13,8 +13,7 @@ class MachineRule(Rule): self.n_machines = n_machines def on_init(self, state, lvl_map): - # TODO Move to spawn!!! - state[m.MACHINES].add_items(Machine(pos) for pos in state.entities.empty_positions()) + state[m.MACHINES].spawn(state.entities.empty_positions()) def tick_pre_step(self, state) -> List[TickResult]: pass diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py index 3582197..e084b0c 100644 --- a/marl_factory_grid/modules/maintenance/entities.py +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -84,14 +84,12 @@ class Maintainer(Entity): def _door_is_close(self, state): state.print("Found a door that is close.") try: - # return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name) except StopIteration: return None def _predict_move(self, state): next_pos = self._path[0] - # if len(state[c.FLOORS].by_pos(next_pos).guests_that_can_collide) > 0: if any(x for x in state.entities.pos_dict[next_pos] if x.var_can_collide) > 0: action = c.NOOP else: diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py index 4a32a2b..2df70cb 100644 --- a/marl_factory_grid/modules/maintenance/groups.py +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -1,25 +1,34 @@ -from typing import List +from typing import Union, List, Tuple +from marl_factory_grid.environment.groups.collection import Collection from .entities import Maintainer -from marl_factory_grid.environment.entity.wall_floor import Floor -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.mixins import PositionMixin +from ..machines import constants as mc from ..machines.actions import MachineAction from ...utils.states import Gamestate -from ..machines import constants as mc - - -class Maintainers(PositionMixin, EnvObjects): +class Maintainers(Collection): _entity = Maintainer - var_can_collide = True - var_can_move = True - var_is_blocking_light = False - var_has_position = True + + @property + def var_can_collide(self): + return True + + @property + def var_can_move(self): + return True + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def spawn(self, position, state: Gamestate): - self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in position]) + def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): + state = entity_args[0] + self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity]) diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index bf73d67..820183e 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -14,7 +14,6 @@ class MaintenanceRule(Rule): self.n_maintainer = n_maintainer def on_init(self, state: Gamestate, lvl_map): - # Move to spawn? : #TODO state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state) pass diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py index f4923b7..cfd313f 100644 --- a/marl_factory_grid/modules/zones/entitites.py +++ b/marl_factory_grid/modules/zones/entitites.py @@ -1,15 +1,10 @@ import random from typing import List, Tuple -from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.entity.object import Object -from marl_factory_grid.utils.utility_classes import RenderEntity -from marl_factory_grid.environment import constants as c - -from marl_factory_grid.modules.doors import constants as d +from marl_factory_grid.environment.entity.object import _Object -class Zone(Object): +class Zone(_Object): @property def positions(self): diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py index e706a29..71eb329 100644 --- a/marl_factory_grid/modules/zones/groups.py +++ b/marl_factory_grid/modules/zones/groups.py @@ -1,12 +1,14 @@ -from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.modules.zones import Zone -class Zones(Objects): - +class Zones(_Objects): symbol = None _entity = Zone - var_can_move = False + + @property + def var_can_move(self): + return False def __init__(self, *args, **kwargs): super(Zones, self).__init__(*args, can_collide=True, **kwargs) diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index 8fd3d3a..e2f3c9a 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -232,3 +232,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): except AttributeError: continue raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules))) + + +def add_bound_name(name_str, bound_e): + return f'{name_str}({bound_e.name})' + + +def add_pos_name(name_str, bound_e): + if bound_e.var_has_position: + return f'{name_str}({bound_e.pos})' + return name_str + + diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 678fba7..9fd1d26 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -1,4 +1,5 @@ import math +import re from collections import defaultdict from itertools import product from typing import Dict, List @@ -8,12 +9,12 @@ from numba import njit from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.utils import Combined +import marl_factory_grid.utils.helpers as h from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.utility_classes import Floor class OBSBuilder(object): - default_obs = [c.WALLS, c.OTHERS] @property @@ -93,13 +94,13 @@ class OBSBuilder(object): agent_want_obs = self.obs_layers[agent.name] # Handle in-grid observations aka visible observations (Things on the map, with pos) - visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict) - pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape)) + visible_entities = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict) + pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape)) if self.pomdp_r: - for e in set(visible_entitites): + for e in set(visible_entities): self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e) else: - for e in set(visible_entitites): + for e in set(visible_entities): pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding pre_sort_obs = dict(pre_sort_obs) @@ -120,13 +121,18 @@ class OBSBuilder(object): e = self.all_obs[l_name] except KeyError: try: - e = self.all_obs[f'{l_name}({agent.name})'] + # Look for bound entity names! + pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') + name = next((x for x in self.all_obs if pattern.search(x)), None) + e = self.all_obs[name] except KeyError: try: - e = next(x for x in self.all_obs if l_name in x and agent.name in x) + e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k) except StopIteration: raise KeyError( - f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}') + f'Check for spelling errors! \n ' + f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' + f'{list(dict(self.all_obs).keys())}') try: positional = e.var_has_position @@ -224,7 +230,7 @@ class RayCaster: return f'{self.__class__.__name__}({self.agent.name})' def build_ray_targets(self): - north = np.array([0, -1])*self.pomdp_r + north = np.array([0, -1]) * self.pomdp_r thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]] rot_M = [ [[math.cos(theta), -math.sin(theta)], @@ -257,8 +263,9 @@ class RayCaster: diag_hits = all([ self.ray_block_cache( key, - lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key])) - for key in ((x, y-cy), (x-cx, y)) + lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool( + pos_dict[key])) + for key in ((x, y - cy), (x - cx, y)) ]) if (cx != 0 and cy != 0) else False visible += entities_hit if not diag_hits else [] diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 953d2d1..4c1f7f2 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -57,7 +57,7 @@ class Gamestate(object): @property def moving_entites(self): - return [y for x in self.entities for y in x if x.var_can_move] # wird das aus dem String gelesen? + return [y for x in self.entities for y in x if x.var_can_move] def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): self.entities = entities @@ -114,15 +114,12 @@ class Gamestate(object): results.extend(on_check_done_result) return results - def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items() if any([e.var_can_collide for e in entity_list_for_position])] return positions def check_move_validity(self, moving_entity, position): - # if (guest.name not in self._guests and not self.is_blocked) - # and not (guest.var_is_blocking_pos and self.is_occupied()): if moving_entity.pos != position and not any( entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):