From b13dff925bba552bf87f1de4e1ae09956fc2650d Mon Sep 17 00:00:00 2001 From: Chanumask Date: Fri, 27 Oct 2023 13:08:01 +0200 Subject: [PATCH] bugfixes --- marl_factory_grid/configs/default_config.yaml | 15 +- .../environment/entity/entity.py | 6 +- .../environment/entity/object.py | 157 ++++++++++-------- marl_factory_grid/environment/entity/util.py | 5 +- .../environment/groups/collection.py | 18 +- .../environment/groups/global_entities.py | 1 - .../environment/groups/objects.py | 14 +- marl_factory_grid/environment/rules.py | 2 +- marl_factory_grid/modules/batteries/groups.py | 2 +- .../modules/destinations/rules.py | 3 +- marl_factory_grid/modules/items/groups.py | 8 +- marl_factory_grid/modules/zones/groups.py | 6 +- .../utils/observation_builder.py | 3 + 13 files changed, 141 insertions(+), 99 deletions(-) diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml index 5e7b152..7612b27 100644 --- a/marl_factory_grid/configs/default_config.yaml +++ b/marl_factory_grid/configs/default_config.yaml @@ -5,6 +5,19 @@ Agents: - Noop - ItemAction Observations: + - Combined: + - Other + - Walls + - GlobalPosition + - Battery + - ChargePods + - DirtPiles + - Destinations + - Doors + - Items + - Inventory + - DropOffLocations + - Maintainers Wolfgang: Actions: - Noop @@ -64,8 +77,6 @@ Rules: done_at_collisions: false AssignGlobalPositions: {} DestinationReachAny: {} - DestinationReach: - n_dests: 1 DestinationSpawn: n_dests: 1 spawn_frequency: 5 diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 9eec210..bd54ea7 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity class Entity(_Object, abc.ABC): """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" - _u_idx = defaultdict(lambda: 0) - @property def state(self): return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) @@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC): except AttributeError: return False - @property def var_can_move(self): try: @@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC): except AttributeError: return False - @property def x(self): return self.pos[0] @@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC): if valid := state.check_move_validity(self, next_pos): for observer in self.observers: observer.notify_del_entity(self) - self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1] + self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1] self._pos = next_pos for observer in self.observers: observer.notify_add_entity(self) diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index 8312e29..da77788 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -14,10 +14,7 @@ class _Object: @property def var_has_position(self): - try: - return self.pos != c.VALUE_NO_POS or False - except AttributeError: - return False + return False @property def var_can_be_bound(self): @@ -36,6 +33,17 @@ class _Object: return f'{self.__class__.__name__}[{self._str_ident}]' return f'{self.__class__.__name__}#{self.u_int}' + # @property + # def name(self): + # name = f"{self.__class__.__name__}" + # if self.bound_entity: + # name += f"[{self.bound_entity.name}]" + # if self._str_ident is not None: + # name += f"({self._str_ident})" + # else: + # name += f"(#{self.u_int})" + # return name + @property def identifier(self): if self._str_ident is not None: @@ -48,6 +56,7 @@ class _Object: return True def __init__(self, str_ident: Union[str, None] = None, **kwargs): + self._bound_entity = None self._observers = [] self._str_ident = str_ident self.u_int = self._identify_and_count_up() @@ -91,73 +100,83 @@ class _Object: def belongs_to_entity(self, entity): return self._bound_entity == entity - -class EnvObject(_Object): - """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" - - _u_idx = defaultdict(lambda: 0) - @property - def obs_tag(self): - try: - return self._collection.name or self.name - except AttributeError: - return self.name + def bound_entity(self): + return self._bound_entity - @property - def var_is_blocking_light(self): - try: - return self._collection.var_is_blocking_light or False - except AttributeError: - return False + def bind_to(self, entity): + self._bound_entity = entity - @property - def var_can_be_bound(self): - try: - return self._collection.var_can_be_bound or False - except AttributeError: - return False + def unbind(self): + self._bound_entity = None - @property - def var_can_move(self): - try: - return self._collection.var_can_move or False - except AttributeError: - return False - @property - def var_is_blocking_pos(self): - try: - return self._collection.var_is_blocking_pos or False - except AttributeError: - return False - - @property - def var_has_position(self): - try: - return self._collection.var_has_position or False - except AttributeError: - return False - - @property - def var_can_collide(self): - try: - return self._collection.var_can_collide or False - except AttributeError: - return False - - @property - def encoding(self): - return c.VALUE_OCCUPIED_CELL - - def __init__(self, **kwargs): - super(EnvObject, self).__init__(**kwargs) - - def change_parent_collection(self, other_collection): - other_collection.add_item(self) - self._collection.delete_env_object(self) - self._collection = other_collection - return self._collection == other_collection - - def summarize_state(self): - return dict(name=str(self.name)) +# class EnvObject(_Object): +# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc...""" +# + # _u_idx = defaultdict(lambda: 0) +# +# @property +# def obs_tag(self): +# try: +# return self._collection.name or self.name +# except AttributeError: +# return self.name +# +# @property +# def var_is_blocking_light(self): +# try: +# return self._collection.var_is_blocking_light or False +# except AttributeError: +# return False +# +# @property +# def var_can_be_bound(self): +# try: +# return self._collection.var_can_be_bound or False +# except AttributeError: +# return False +# +# @property +# def var_can_move(self): +# try: +# return self._collection.var_can_move or False +# except AttributeError: +# return False +# +# @property +# def var_is_blocking_pos(self): +# try: +# return self._collection.var_is_blocking_pos or False +# except AttributeError: +# return False +# +# @property +# def var_has_position(self): +# try: +# return self._collection.var_has_position or False +# except AttributeError: +# return False +# +# @property +# def var_can_collide(self): +# try: +# return self._collection.var_can_collide or False +# except AttributeError: +# return False +# +# @property +# def encoding(self): +# return c.VALUE_OCCUPIED_CELL +# +# def __init__(self, **kwargs): +# super(EnvObject, self).__init__(**kwargs) +# +# def change_parent_collection(self, other_collection): +# other_collection.add_item(self) +# self._collection.delete_env_object(self) +# self._collection = other_collection +# return self._collection == other_collection +# +# def summarize_state(self): +# return dict(name=str(self.name)) diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index 945ec87..1a5cbe3 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -1,9 +1,6 @@ -import math - import numpy as np -from marl_factory_grid.environment.entity.mixin import BoundEntityMixin -from marl_factory_grid.environment.entity.object import _Object, EnvObject +from marl_factory_grid.environment.entity.object import _Object ########################################################################## diff --git a/marl_factory_grid/environment/groups/collection.py b/marl_factory_grid/environment/groups/collection.py index 405fbfe..ee54c03 100644 --- a/marl_factory_grid/environment/groups/collection.py +++ b/marl_factory_grid/environment/groups/collection.py @@ -2,11 +2,11 @@ from typing import List, Tuple from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.groups.objects import _Objects -from marl_factory_grid.environment.entity.object import EnvObject +from marl_factory_grid.environment.entity.object import _Object class Collection(_Objects): - _entity = EnvObject # entity? object? objects? + _entity = _Object # entity? @property def var_is_blocking_light(self): @@ -22,13 +22,13 @@ class Collection(_Objects): @property def var_has_position(self): - return False # alles was posmixin hat true + return False + + # @property + # def var_has_bound(self): + # return False # batteries, globalpos, inventories true @property - def var_has_bound(self): - return False # batteries, globalpos, inventories true - - @property # beide bounds hier? inventory can be bound def var_can_be_bound(self): return False @@ -40,12 +40,12 @@ class Collection(_Objects): super(Collection, self).__init__(*args, **kwargs) self.size = size - def add_item(self, item: EnvObject): + def add_item(self, item: Entity): assert self.var_has_position or (len(self) <= self.size) super(Collection, self).add_item(item) return self - def delete_env_object(self, env_object: EnvObject): + def delete_env_object(self, env_object): del self[env_object.name] def delete_env_object_by_name(self, name): diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 9b0555a..7ab48dd 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -36,7 +36,6 @@ class Entities(_Objects): def guests_that_can_collide(self, pos): return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] - @property def empty_positions(self): empty_positions= [key for key in self.floorlist if self.pos_dict[key]] shuffle(empty_positions) diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index bea9521..87c9337 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -122,7 +122,7 @@ class _Objects: raise TypeError def __repr__(self): - repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]} + repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]} return f'{self.__class__.__name__}[{repr_dict}]' def spawn(self, n: int): @@ -169,3 +169,15 @@ class _Objects: # FIXME PROTOBUFF # return [e.summarize_state() for e in self] return [e.summarize_state() for e in self] + + def by_entity(self, entity): + try: + return next((x for x in self if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None + + def idx_by_entity(self, entity): + try: + return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity))) + except (StopIteration, AttributeError): + return None diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 79d4e27..f9678b0 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -49,7 +49,7 @@ class SpawnAgents(Rule): agent_conf = state.agents_conf # agents = Agents(lvl_map.size) agents = state[c.AGENT] - empty_positions = state.entities.empty_positions[:len(agent_conf)] + empty_positions = state.entities.empty_positions()[:len(agent_conf)] for agent_name in agent_conf: actions = agent_conf[agent_name]['actions'].copy() observations = agent_conf[agent_name]['observations'].copy() diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py index cc0a09d..ee057aa 100644 --- a/marl_factory_grid/modules/batteries/groups.py +++ b/marl_factory_grid/modules/batteries/groups.py @@ -20,7 +20,7 @@ class Batteries(Collection): @property def var_has_position(self): - return True + return False @property def obs_tag(self): diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index 8773f2d..b5eb9f2 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -36,7 +36,6 @@ class DestinationReachAll(Rule): results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) return results - def on_check_done(self, state) -> List[DoneResult]: if all(x.was_reached() for x in state[d.DESTINATION]): return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] @@ -56,7 +55,7 @@ class DestinationReachAny(DestinationReachAll): class DestinationSpawn(Rule): - def __init__(self, n_dests: int = 1, + def __init__(self, n_dests: int = 1, spawn_frequency: int = 5, spawn_mode: str = d.MODE_GROUPED): super(DestinationSpawn, self).__init__() self.n_dests = n_dests diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 997f91b..707f743 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.objects import _Objects -from marl_factory_grid.environment.groups.mixins import IsBoundMixin, HasBoundMixin +from marl_factory_grid.environment.groups.mixins import IsBoundMixin from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.modules.items.entitites import Item, DropOffLocation @@ -45,6 +45,10 @@ class Items(Collection): class Inventory(IsBoundMixin, Collection): _accepted_objects = Item + @property + def var_can_be_bound(self): + return True + @property def obs_tag(self): return self.name @@ -69,7 +73,7 @@ class Inventory(IsBoundMixin, Collection): self._collection = collection -class Inventories(HasBoundMixin, _Objects): +class Inventories(_Objects): _entity = Inventory @property diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py index 2f668f6..71eb329 100644 --- a/marl_factory_grid/modules/zones/groups.py +++ b/marl_factory_grid/modules/zones/groups.py @@ -3,10 +3,12 @@ from marl_factory_grid.modules.zones import Zone class Zones(_Objects): - symbol = None _entity = Zone - var_can_move = False + + @property + def var_can_move(self): + return False def __init__(self, *args, **kwargs): super(Zones, self).__init__(*args, can_collide=True, **kwargs) diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 1377a92..b9d3eac 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -103,6 +103,7 @@ class OBSBuilder(object): obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1])) for idx, l_name in enumerate(agent_want_obs): + print(l_name) try: obs[idx] = pre_sort_obs[l_name] except KeyError: @@ -141,6 +142,8 @@ class OBSBuilder(object): try: v = e.encoding except AttributeError: + print(e) + print(e.var_has_position) raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"') try: np.put(obs[idx], range(len(v)), v, mode='raise')