From fbbf8d6f6e9b0bdc10c93856e8957539f24cf11c Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Mon, 30 Oct 2023 10:08:40 +0100 Subject: [PATCH] naming Functions --- marl_factory_grid/environment/entity/object.py | 11 +++++++++-- marl_factory_grid/utils/helpers.py | 12 ++++++++++++ marl_factory_grid/utils/observation_builder.py | 8 ++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index da77788..c1518be 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -2,6 +2,7 @@ from collections import defaultdict from typing import Union from marl_factory_grid.environment import constants as c +import marl_factory_grid.utils.helpers as h class _Object: @@ -30,8 +31,14 @@ class _Object: @property def name(self): if self._str_ident is not None: - return f'{self.__class__.__name__}[{self._str_ident}]' - return f'{self.__class__.__name__}#{self.u_int}' + name = f'{self.__class__.__name__}[{self._str_ident}]' + else: + name = f'{self.__class__.__name__}#{self.u_int}' + if self.bound_entity: + name = h.add_bound_name(name, self.bound_entity) + if self.var_has_position: + name = h.add_pos_name(name, self) + return name # @property # def name(self): diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index 8fd3d3a..ad27d29 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -232,3 +232,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): except AttributeError: continue raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules))) + + +def add_bound_name(name_str, bound_e): + return f'{name_str}({bound_e.identifier})' + + +def add_pos_name(name_str, bound_e): + if bound_e.var_has_pos: + return f'{name_str}({bound_e.pos})' + return name_str + + diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index b9d3eac..6bc0eae 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -8,6 +8,7 @@ from numba import njit from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.utils import Combined +import marl_factory_grid.utils.helpers as h from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.utility_classes import Floor @@ -118,13 +119,16 @@ class OBSBuilder(object): e = self.all_obs[l_name] except KeyError: try: - e = self.all_obs[f'{l_name}({agent.name})'] + # Look for bound entity names! + e = self.all_obs[h.add_bound_name(l_name, agent)] except KeyError: try: e = next(x for x in self.all_obs if l_name in x and agent.name in x) except StopIteration: raise KeyError( - f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}') + f'Check for spelling errors! \n ' + f'No combination of "{l_name} and {agent.name}" could not be found in:\n ' + f'{list(dict(self.all_obs).keys())}') try: positional = e.var_has_position