diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index e2cd1eb..e2f3c9a 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -239,7 +239,7 @@ def add_bound_name(name_str, bound_e): def add_pos_name(name_str, bound_e): - if bound_e.var_has_pos: + if bound_e.var_has_position: return f'{name_str}({bound_e.pos})' return name_str diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 6bc0eae..f5aa6ec 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -1,4 +1,5 @@ import math +import re from collections import defaultdict from itertools import product from typing import Dict, List @@ -120,10 +121,13 @@ class OBSBuilder(object): except KeyError: try: # Look for bound entity names! - e = self.all_obs[h.add_bound_name(l_name, agent)] + pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}') + print(pattern) + name = next((x for x in self.all_obs if pattern.search(x)), None) + e = self.all_obs[name] except KeyError: try: - e = next(x for x in self.all_obs if l_name in x and agent.name in x) + e = next(v for k in self.all_obs.items() if l_name in k and agent.name in k) except StopIteration: raise KeyError( f'Check for spelling errors! \n ' @@ -146,8 +150,6 @@ class OBSBuilder(object): try: v = e.encoding except AttributeError: - print(e) - print(e.var_has_position) raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"') try: np.put(obs[idx], range(len(v)), v, mode='raise')