import re
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h


class OBSBuilder(object):
    """
    Builds per-agent observation arrays from the current game state.

    One 2D layer is produced per entry of an agent's observation configuration; the layers are
    stacked into an array of shape ``(n_layers, *obs_shape)``.
    """

    # Layers inserted when an agent's observation configuration contains ``c.DEFAULTS``.
    default_obs = [c.WALLS, c.OTHERS]

    @property
    def pomdp_d(self):
        """
        Calculates the effective diameter of the POMDP observation space.

        :return: ``2 * pomdp_r + 1`` if a POMDP radius is set, otherwise 0.
        :rtype: int
        """
        return (self.pomdp_r * 2) + 1 if self.pomdp_r else 0

    def __init__(self, level_shape: Tuple[int, int], state: Gamestate, pomdp_r: int):
        """
        OBSBuilder
        ==========

        The OBSBuilder class is responsible for constructing observations in the environment.

        :param level_shape: The shape of the level or environment; used as the observation shape when
                            `pomdp_r` is falsy.
        :type level_shape: Tuple[int, int]
        :param state: The current game state.
        :type state: marl_factory_grid.environment.state.Gamestate
        :param pomdp_r: The POMDP radius, influencing the size of the observation space.
        :type pomdp_r: int
        """
        self.all_obs = dict()
        self.ray_caster = dict()

        self.level_shape = level_shape
        self.pomdp_r = pomdp_r
        # Partial observability yields a square window around the agent, otherwise the full level.
        self.obs_shape = (self.pomdp_d, self.pomdp_d) if self.pomdp_r else self.level_shape
        self.size = np.prod(self.obs_shape)

        self.obs_layers = dict()
        self.curr_lightmaps = dict()

        # One Floor object per floor position, wrapped in a list per position.
        self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})

        self.reset(state)

    def reset(self, state):
        """
        Resets temporary information and constructs an empty observation array with possible placeholders.

        :param state: The current game state; its ``entities.obs_pairs`` are copied into `all_obs`.
        :returns: Always True.
        :rtype: bool
        """
        # Reset temporary information
        self.curr_lightmaps = dict()
        # Construct an empty obs (array) for possible placeholders
        self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
        # Fill the all_obs-dict with all available entities
        self.all_obs.update({key: obj for key, obj in state.entities.obs_pairs})
        return True

    def observation_space(self, state):
        """
        Returns the observation space for a single agent or a tuple of spaces for multiple agents.

        The builder is reset and one observation per agent is built to derive the shapes.

        :returns: The observation space for the agent(s).
        :rtype: gym.Space|Tuple
        """
        # Imported lazily and as a module so that ``spaces.Tuple`` does not shadow ``typing.Tuple``.
        from gymnasium import spaces

        self.reset(state)
        obsn = self.build_for_all(state)
        if len(state[c.AGENT]) == 1:
            single_obs = next(iter(obsn.values()))
            return spaces.Box(low=0, high=1, shape=single_obs.shape, dtype=np.float32)
        return spaces.Tuple([spaces.Box(low=0, high=1, shape=obs.shape, dtype=np.float32)
                             for obs in obsn.values()])

    def named_observation_space(self, state):
        """
        Resets the builder and builds one observation per agent.

        Note that the values are the observation arrays produced by `build_for_all`,
        not gymnasium space objects.

        :returns: A dictionary mapping agent names to their observation arrays.
        :rtype: dict
        """
        self.reset(state)
        return self.build_for_all(state)

    def build_for_all(self, state) -> Dict[str, np.ndarray]:
        """
        Builds observations for all agents in the environment.

        :returns: A dictionary mapping agent names to their observation arrays.
        :rtype: dict
        """
        return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}

    def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
        """
        Builds named observations for all agents in the environment.

        :returns: A dictionary mapping agent names to ``{'observation': array, 'names': layer_names}``.
        :rtype: dict
        """
        named_obs_dict = {}
        for agent in state[c.AGENT]:
            obs, names = self.build_for_agent(agent, state)
            named_obs_dict[agent.name] = {'observation': obs, 'names': names}
        return named_obs_dict

    def place_entity_in_observation(self, obs_array, agent, e):
        """
        Places the encoding of an entity in the observation array relative to the agent's position.

        The array is modified in place. Entities whose relative position falls outside of the array
        are ignored.

        :param obs_array: The observation array.
        :type obs_array: np.ndarray
        :param agent: the associated agent
        :type agent: Agent
        :param e: The entity to be placed in the observation.
        :type e: Entity
        """
        x = (e.x - agent.x) + self.pomdp_r
        y = (e.y - agent.y) + self.pomdp_r
        # Negative indices would silently wrap around in numpy, so they are rejected explicitly.
        if x < 0 or y < 0:
            return
        try:
            obs_array[x, y] += e.encoding
        except IndexError:
            # Seemed to be visible but is out of range
            pass

    def build_for_agent(self, agent, state) -> Tuple[np.ndarray, List[str]]:
        """
        Builds observations for a specific agent.

        :param agent: The agent for which the observation is built.
        :param state: The current game state.
        :returns: A tuple containing the observation array of shape ``(n_layers, *obs_shape)`` and the
                  corresponding list of layer names.
        :rtype: Tuple[np.ndarray, List[str]]
        """
        try:
            agent_want_obs = self.obs_layers[agent.name]
        except KeyError:
            # First call for this agent: derive its layer list (and ray caster) from its configuration.
            self._sort_and_name_observation_conf(agent)
            agent_want_obs = self.obs_layers[agent.name]

        # Handle in-grid observations aka visible observations (Things on the map, with pos)
        visible_entities = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
        pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
        if self.pomdp_r:
            for e in set(visible_entities):
                self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
        else:
            for e in set(visible_entities):
                pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding

        # Freeze into a plain dict so that lookups below no longer create empty layers.
        pre_sort_obs = dict(pre_sort_obs)
        obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))

        for idx, l_name in enumerate(agent_want_obs):
            if l_name in pre_sort_obs:
                obs[idx] = pre_sort_obs[l_name]
            elif c.COMBINED in l_name:
                combined_names = self.all_obs[f'{c.COMBINED}({agent.name})'].names
                if combined := [pre_sort_obs[x] for x in combined_names if x in pre_sort_obs]:
                    obs[idx] = np.sum(combined, axis=0)
            elif l_name == c.PLACEHOLDER:
                obs[idx] = self.all_obs[c.PLACEHOLDER]
            else:
                source = self._resolve_layer_source(l_name, agent)
                # A positional source that is missing from pre_sort_obs was not among the visible
                # entities, so its layer simply stays zero.
                if not getattr(source, 'var_has_position', False):
                    self._write_encoding(obs[idx], source)

        if self.pomdp_r:
            self._update_lightmap(agent)
        return obs, self.obs_layers[agent.name]

    def _resolve_layer_source(self, l_name, agent):
        """
        Finds the object in `all_obs` that provides the layer `l_name` for `agent`.

        Lookup order: exact key, then an `Object` whose string representation matches
        ``<l_name>[...](<agent.name>)``, then the first key containing both `l_name` and the agent's name.

        :param l_name: The name of the requested observation layer.
        :param agent: The agent for which the layer is built.
        :returns: The matching value of `all_obs`.
        :raises SystemExit: With code -99999 if nothing matches.
        """
        try:
            return self.all_obs[l_name]
        except KeyError:
            pass

        # Look for bound entity REPRs!
        pattern = re.compile(f'{re.escape(l_name)}'
                             f'{re.escape("[")}(.*){re.escape("]")}'
                             f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
        # The cheap type check goes first so that only Objects are converted to strings.
        name = next((key for key, val in self.all_obs.items()
                     if isinstance(val, Object) and pattern.search(str(val))), None)
        if name is not None:
            return self.all_obs[name]

        try:
            return next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
        except StopIteration:
            print('# Check for spelling errors!')
            print(f'# No combination of "{l_name}" and "{agent.name}" could be found in:')
            print(f'# {list(dict(self.all_obs).keys())}')
            print('#')
            print('# exiting...')
            print('#')
            # Raised directly instead of calling the site-provided ``exit`` helper, which is not
            # guaranteed to be available; the exit code is unchanged.
            raise SystemExit(-99999)

    @staticmethod
    def _write_encoding(layer, source):
        """
        Writes the encoding(s) of a non-positional `source` into the leading cells of `layer`, in place.

        :param layer: A single observation layer (a view into the agent's observation array).
        :type layer: np.ndarray
        :param source: An object exposing either `encodings` or `encoding`.
        :raises AttributeError: If `source` has neither `encodings` nor `encoding`.
        :raises ValueError: If there are more encoding values than cells in `layer`.
        """
        try:
            values = source.encodings
        except AttributeError:
            try:
                values = source.encoding
            except AttributeError as err:
                raise AttributeError('This env. expects Entity-Clases to report their "encoding"') from err
        try:
            np.put(layer, range(len(values)), values, mode='raise')
        except TypeError:
            # `values` has no len(), i.e. it is a single value: write it to the first cell.
            np.put(layer, 0, values, mode='raise')
        except IndexError as err:
            raise ValueError(f'Max(obs.size) for {source.name}: {layer.size}, but was: {len(values)}.') from err

    def _update_lightmap(self, agent):
        """
        Recomputes the agent's entry in `curr_lightmaps` from the floor tiles its ray caster reports.

        :param agent: The agent whose light map is refreshed.
        """
        try:
            light_map = self.curr_lightmaps.get(agent.name)
            if light_map is None:
                light_map = np.zeros(self.obs_shape)
            light_map[:] = 0.0
            visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
            for f in set(visible_floor):
                self.place_entity_in_observation(light_map, agent, f)
            self.curr_lightmaps[agent.name] = light_map
        except (KeyError, ValueError):
            # NOTE(review): failures are swallowed here as in the original implementation, so the
            # light map appears to be best-effort - confirm that this is intended.
            pass

    def _sort_and_name_observation_conf(self, agent):
        """
        Builds the useable observation scheme per agent from conf.yaml.

        Registers a ray caster, the ordered list of layer names and an empty light map for the agent.

        :param agent: The agent for whom the observation scheme is built.
        """
        # Fixme: no asymmetric shapes possible.
        self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
        obs_layers = []

        for obs_str in agent.observations:
            # A dict entry carries the layer keyword together with its arguments.
            if isinstance(obs_str, dict):
                obs_str, vals = h.get_first(obs_str.items())
            else:
                vals = None

            if obs_str == c.SELF:
                obs_layers.append(agent.name)
            elif obs_str == c.DEFAULTS:
                obs_layers.extend(self.default_obs)
            elif obs_str == c.COMBINED:
                if isinstance(vals, str):
                    vals = [vals]
                names = []
                for val in vals:
                    if val == c.SELF:
                        names.append(agent.name)
                    elif val == c.OTHERS:
                        names.extend([x.name for x in agent.collection if x.name != agent.name])
                    else:
                        names.append(val)
                combined = Combined(names, self.size, identifier=agent.name)
                self.all_obs[combined.name] = combined
                obs_layers.append(combined.name)
            elif obs_str == c.OTHERS:
                obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
            elif obs_str == c.AGENT:
                obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
            else:
                obs_layers.append(obs_str)

        self.obs_layers[agent.name] = obs_layers
        self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)