mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
added vizualization for violations
This commit is contained in:
parent
5a6a444af4
commit
a6793e11ca
@ -67,6 +67,7 @@ class BaseFactory(gym.Env):
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
self.monitor = FactoryMonitor(self)
|
||||
self.agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||
floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL)
|
||||
@ -74,9 +75,13 @@ class BaseFactory(gym.Env):
|
||||
np.random.shuffle(floor_tiles)
|
||||
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
|
||||
agents[i, x, y] = h.IS_OCCUPIED_CELL
|
||||
agent_state = AgentState(i, -1)
|
||||
agent_state.update(pos=[x, y])
|
||||
self.agent_states.append(agent_state)
|
||||
# state.shape = level, agent 1,..., agent n,
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
# Returns State, Reward, Done, Info
|
||||
|
||||
return self.state, 0, self.done, {}
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@ -102,6 +107,7 @@ class BaseFactory(gym.Env):
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
|
||||
states[i].update(collision_vector=collision_vec)
|
||||
|
||||
self.agent_states = states
|
||||
reward, info = self.calculate_reward(states)
|
||||
self.cumulative_reward += reward
|
||||
|
||||
|
@ -41,9 +41,9 @@ class Renderer:
|
||||
rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
|
||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||
|
||||
def blit_params(self, entity, name):
|
||||
def blit_params(self, entity):
|
||||
r, c = entity.pos
|
||||
img = self.assets[name]
|
||||
img = self.assets[entity.name]
|
||||
if entity.value_operation == 'opacity':
|
||||
img.set_alpha(255*entity.value)
|
||||
elif entity.value_operation == 'scale':
|
||||
@ -71,7 +71,7 @@ class Renderer:
|
||||
self.fill_bg()
|
||||
for asset, entities in pos_dict.items():
|
||||
for entity in entities:
|
||||
bp = self.blit_params(entity, asset)
|
||||
bp = self.blit_params(entity)
|
||||
if 'agent' in asset and self.view_radius > 0:
|
||||
visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
|
||||
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
||||
|
@ -41,10 +41,15 @@ class GettingDirty(BaseFactory):
|
||||
|
||||
dirt = [Entity('dirt', [x, y], min(1.1*self.state[DIRT_INDEX, x, y], 1), 'opacity')
|
||||
for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
agents = [Entity('agent1', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
|
||||
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent1=agents))
|
||||
agents = {f'agent{i+1}': [Entity(f'agent{i+1}'
|
||||
if (agent.action_valid and agent.collision_vector[h.LEVEL_IDX] <= 0) else f'agent{i+1}violation',
|
||||
agent.pos)
|
||||
]
|
||||
for i, agent in enumerate(self.agent_states)}
|
||||
print(agents)
|
||||
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
|
||||
|
||||
def spawn_dirt(self) -> None:
|
||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||
|
Loading…
x
Reference in New Issue
Block a user