mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
fixed lightmap
This commit is contained in:
parent
374a38971a
commit
35a42d7d47
@ -74,6 +74,14 @@ class OBSBuilder(object):
|
||||
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out of range
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
assert self._curr_env_step == state.curr_step, (
|
||||
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
||||
@ -89,12 +97,7 @@ class OBSBuilder(object):
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
if self.pomdp_r:
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||
else:
|
||||
for e in set(visible_entitites):
|
||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||
@ -151,15 +154,16 @@ class OBSBuilder(object):
|
||||
if self.pomdp_r:
|
||||
try:
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
||||
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||
if self.pomdp_r:
|
||||
# Fixme: This Sucks if the Map is too small!!
|
||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
else:
|
||||
coords = [x.pos for x in visible_floor]
|
||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
||||
for f in set(visible_floor):
|
||||
light_map[f.x, f.y] += f.encoding
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except (KeyError, ValueError):
|
||||
print()
|
||||
pass
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
|
@ -31,6 +31,10 @@ class RenderEntity:
|
||||
@dataclass
|
||||
class Floor:
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f"Floor({self.pos})"
|
||||
|
Loading…
x
Reference in New Issue
Block a user