mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
fixed lightmap
This commit is contained in:
parent
374a38971a
commit
35a42d7d47
@ -74,6 +74,14 @@ class OBSBuilder(object):
|
|||||||
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
||||||
return named_obs_dict
|
return named_obs_dict
|
||||||
|
|
||||||
|
def place_entity_in_observation(self, obs_array, agent, e):
|
||||||
|
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||||
|
try:
|
||||||
|
obs_array[x, y] += e.encoding
|
||||||
|
except IndexError:
|
||||||
|
# Seemded to be visible but is out of range
|
||||||
|
pass
|
||||||
|
|
||||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||||
assert self._curr_env_step == state.curr_step, (
|
assert self._curr_env_step == state.curr_step, (
|
||||||
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
||||||
@ -89,12 +97,7 @@ class OBSBuilder(object):
|
|||||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entitites):
|
||||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||||
try:
|
|
||||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
|
||||||
except IndexError:
|
|
||||||
# Seemded to be visible but is out or range
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entitites):
|
||||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||||
@ -151,15 +154,16 @@ class OBSBuilder(object):
|
|||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
try:
|
try:
|
||||||
light_map = np.zeros(self.obs_shape)
|
light_map = np.zeros(self.obs_shape)
|
||||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
# Fixme: This Sucks if the Map is too small!!
|
for f in set(visible_floor):
|
||||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
self.place_entity_in_observation(light_map, agent, f)
|
||||||
else:
|
else:
|
||||||
coords = [x.pos for x in visible_floor]
|
for f in set(visible_floor):
|
||||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
light_map[f.x, f.y] += f.encoding
|
||||||
self.curr_lightmaps[agent.name] = light_map
|
self.curr_lightmaps[agent.name] = light_map
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
|
print()
|
||||||
pass
|
pass
|
||||||
return obs, self.obs_layers[agent.name]
|
return obs, self.obs_layers[agent.name]
|
||||||
|
|
||||||
|
@ -31,6 +31,10 @@ class RenderEntity:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Floor:
|
class Floor:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return f"Floor({self.pos})"
|
return f"Floor({self.pos})"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user