From 30af16190edf40f87a82b5708bbc3d95de382956 Mon Sep 17 00:00:00 2001 From: romue Date: Tue, 18 May 2021 17:25:20 +0200 Subject: [PATCH] added viz. of agent collision --- environments/factory/simple_factory_getting_dirty.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 190f3c5..39dd988 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -42,8 +42,13 @@ class GettingDirty(BaseFactory): dirt = [Entity('dirt', [x, y], min(0.15+self.state[DIRT_INDEX, x, y], 1.5), 'scale') for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] - asset_str = lambda agent: f'agent{agent.i+1}violation' if (not agent.action_valid or agent.collision_vector[h.LEVEL_IDX] > 0)\ - else (f'agent{agent.i+1}valid' if self._is_clean_up_action(agent.action) else f'agent{agent.i+1}') + + def asset_str(agent): + cols = ' '.join([self.slice_strings[j] for j in agent.collisions]) + asset_str = f'agent{agent.i + 1}violation' if (not agent.action_valid or 'level' in cols or 'agent' in cols) \ + else (f'agent{agent.i + 1}valid' if self._is_clean_up_action(agent.action) else f'agent{agent.i + 1}') + return asset_str + agents = {f'agent{i+1}': [Entity(asset_str(agent), agent.pos)] for i, agent in enumerate(self.agent_states)} self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))