mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Render all spawnpoints that are matched with a target dirt pile + Fixed arrow placement
This commit is contained in:
parent
defbaf6f93
commit
98113ea849
@ -1,3 +1,5 @@
|
||||
import ast
|
||||
import os
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
@ -147,27 +149,46 @@ def plot_action_maps(factory, agents, result_path):
|
||||
target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
|
||||
action_entities.append(
|
||||
RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
|
||||
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(
|
||||
factory.state.agent_spawn_positions[agent_index])))
|
||||
|
||||
# Render all spawnpoints assigned to current target dirt pile
|
||||
spawnpoints = list(factory.state.agents_conf.values())[agent_index]['positions']
|
||||
all_target_dirts = []
|
||||
if 'DirtPiles' in factory.conf['Entities']:
|
||||
tuples = ast.literal_eval(factory.conf['Entities']['DirtPiles']['coords_or_quantity'])
|
||||
for t in tuples:
|
||||
all_target_dirts.append(t)
|
||||
assigned_spawn_positions = []
|
||||
for j in range(len(spawnpoints) // len(all_target_dirts)):
|
||||
assigned_spawn_positions.append(spawnpoints[j * len(all_target_dirts) + all_target_dirts.index(target_dirt_pos)])
|
||||
for spawn_pos in assigned_spawn_positions:
|
||||
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(spawn_pos)))
|
||||
|
||||
render_arrows = []
|
||||
for position, probabilities in probabilities_map.items():
|
||||
if position not in wall_positions:
|
||||
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
||||
sorted_indices = sorted(range(len(probabilities)), key=lambda i: -probabilities[i])
|
||||
sorted_indices = np.argsort(np.argsort(-probabilities))
|
||||
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
|
||||
|
||||
render_arrows.append([])
|
||||
for rank, direction_index in enumerate(sorted_indices):
|
||||
action = directions[direction_index]
|
||||
probability = probabilities[direction_index]
|
||||
arrow_color = colors[rank]
|
||||
if probability > 0:
|
||||
action_entity = RenderEntity(
|
||||
name=arrow_color,
|
||||
pos=position,
|
||||
probability=probability,
|
||||
rotation=direction_index * 90
|
||||
)
|
||||
action_entities.append(action_entity)
|
||||
probability = probabilities[rank]
|
||||
arrow_color = colors[direction_index]
|
||||
render_arrows[-1].append((probability, arrow_color, position))
|
||||
|
||||
# Swap west and east
|
||||
for l in render_arrows:
|
||||
l[1], l[3] = l[3], l[1]
|
||||
for l in render_arrows:
|
||||
for rank, (probability, arrow_color, position) in enumerate(l):
|
||||
if probability > 0:
|
||||
action_entity = RenderEntity(
|
||||
name=arrow_color,
|
||||
pos=position,
|
||||
probability=probability,
|
||||
rotation=rank * 90
|
||||
)
|
||||
action_entities.append(action_entity)
|
||||
|
||||
renderer.render_multi_action_icons(action_entities, result_path)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user