diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index bb5bf1e..678fba7 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -74,6 +74,14 @@ class OBSBuilder(object): named_obs_dict[agent.name] = {'observation': obs, 'names': names} return named_obs_dict + def place_entity_in_observation(self, obs_array, agent, e): + x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r + try: + obs_array[x, y] += e.encoding + except IndexError: + # Seemded to be visible but is out of range + pass + def build_for_agent(self, agent, state) -> (List[str], np.ndarray): assert self._curr_env_step == state.curr_step, ( "The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'" @@ -89,12 +97,7 @@ class OBSBuilder(object): pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape)) if self.pomdp_r: for e in set(visible_entitites): - x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r - try: - pre_sort_obs[e.obs_tag][x, y] += e.encoding - except IndexError: - # Seemded to be visible but is out or range - pass + self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e) else: for e in set(visible_entitites): pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding @@ -151,15 +154,16 @@ class OBSBuilder(object): if self.pomdp_r: try: light_map = np.zeros(self.obs_shape) - visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)) + visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) if self.pomdp_r: - # Fixme: This Sucks if the Map is too small!! - coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor] + for f in set(visible_floor): + self.place_entity_in_observation(light_map, agent, f) else: - coords = [x.pos for x in visible_floor] - np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1) + for f in set(visible_floor): + light_map[f.x, f.y] += f.encoding self.curr_lightmaps[agent.name] = light_map except (KeyError, ValueError): + print() pass return obs, self.obs_layers[agent.name] diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py index 5574a81..4844133 100644 --- a/marl_factory_grid/utils/utility_classes.py +++ b/marl_factory_grid/utils/utility_classes.py @@ -31,6 +31,10 @@ class RenderEntity: @dataclass class Floor: + @property + def encoding(self): + return 1 + @property def name(self): return f"Floor({self.pos})"