From 9b28b011258e87f71a8477e61965d85348e8fdd6 Mon Sep 17 00:00:00 2001 From: romue Date: Tue, 18 May 2021 14:29:55 +0200 Subject: [PATCH] added vizualization for violations --- .../factory/simple_factory_getting_dirty.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 74501eb..e3cb5fe 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -39,16 +39,12 @@ class GettingDirty(BaseFactory): height, width = self.state.shape[1:] self.renderer = Renderer(width, height, view_radius=2) - dirt = [Entity('dirt', [x, y], min(1.1*self.state[DIRT_INDEX, x, y], 1), 'opacity') - for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] - walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] - - agents = {f'agent{i+1}': [Entity(f'agent{i+1}' - if (agent.action_valid and agent.collision_vector[h.LEVEL_IDX] <= 0) else f'agent{i+1}violation', - agent.pos) - ] + dirt = [Entity('dirt', [x, y], min(1.1*self.state[DIRT_INDEX, x, y], 1), 'opacity') + for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] + walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] + violation = lambda agent: agent.action_valid and agent.collision_vector[h.LEVEL_IDX] <= 0 + agents = {f'agent{i+1}': [Entity(f'agent{i+1}' if violation(agent) else f'agent{i+1}violation', agent.pos)] for i, agent in enumerate(self.agent_states)} - print(agents) self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents)) def spawn_dirt(self) -> None: