diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 190f3c5..39dd988 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -42,8 +42,13 @@ class GettingDirty(BaseFactory): dirt = [Entity('dirt', [x, y], min(0.15+self.state[DIRT_INDEX, x, y], 1.5), 'scale') for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] - asset_str = lambda agent: f'agent{agent.i+1}violation' if (not agent.action_valid or agent.collision_vector[h.LEVEL_IDX] > 0)\ - else (f'agent{agent.i+1}valid' if self._is_clean_up_action(agent.action) else f'agent{agent.i+1}') + + def asset_str(agent): + cols = ' '.join([self.slice_strings[j] for j in agent.collisions]) + asset_str = f'agent{agent.i + 1}violation' if (not agent.action_valid or 'level' in cols or 'agent' in cols) \ + else (f'agent{agent.i + 1}valid' if self._is_clean_up_action(agent.action) else f'agent{agent.i + 1}') + return asset_str + agents = {f'agent{i+1}': [Entity(asset_str(agent), agent.pos)] for i, agent in enumerate(self.agent_states)} self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))