mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
fixed render funciton and obsbuilder
This commit is contained in:
@@ -15,7 +15,6 @@ from marl_factory_grid.utils.utility_classes import Floor
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
|
||||
default_obs = [c.WALLS, c.OTHERS]
|
||||
|
||||
@property
|
||||
@@ -95,20 +94,19 @@ class OBSBuilder(object):
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
|
||||
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
visible_entities = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
if self.pomdp_r:
|
||||
for e in set(visible_entitites):
|
||||
for e in set(visible_entities):
|
||||
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||
else:
|
||||
for e in set(visible_entitites):
|
||||
for e in set(visible_entities):
|
||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||
|
||||
pre_sort_obs = dict(pre_sort_obs)
|
||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||
|
||||
for idx, l_name in enumerate(agent_want_obs):
|
||||
print(l_name)
|
||||
try:
|
||||
obs[idx] = pre_sort_obs[l_name]
|
||||
except KeyError:
|
||||
@@ -125,12 +123,11 @@ class OBSBuilder(object):
|
||||
try:
|
||||
# Look for bound entity names!
|
||||
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
||||
print(pattern)
|
||||
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
||||
e = self.all_obs[name]
|
||||
except KeyError:
|
||||
try:
|
||||
e = next(v for k in self.all_obs.items() if l_name in k and agent.name in k)
|
||||
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
||||
except StopIteration:
|
||||
raise KeyError(
|
||||
f'Check for spelling errors! \n '
|
||||
@@ -233,7 +230,7 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1])*self.pomdp_r
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
@@ -266,8 +263,9 @@ class RayCaster:
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
|
||||
pos_dict[key]))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
|
||||
Reference in New Issue
Block a user