From 98113ea8492c15c117edaf755bb176a0f8aa8131 Mon Sep 17 00:00:00 2001 From: Chanumask Date: Sun, 12 May 2024 11:48:05 +0200 Subject: [PATCH] Render all spawnpoints that are matched with a target dirt pile + Fixed arrow placement --- .../utils/plotting/plot_single_runs.py | 49 +++++++++++++------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/marl_factory_grid/utils/plotting/plot_single_runs.py b/marl_factory_grid/utils/plotting/plot_single_runs.py index b44446d..5fbb024 100644 --- a/marl_factory_grid/utils/plotting/plot_single_runs.py +++ b/marl_factory_grid/utils/plotting/plot_single_runs.py @@ -1,3 +1,5 @@ +import ast +import os import pickle from os import PathLike from pathlib import Path @@ -147,27 +149,46 @@ def plot_action_maps(factory, agents, result_path): target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos action_entities.append( RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos))) - action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates( - factory.state.agent_spawn_positions[agent_index]))) + # Render all spawnpoints assigned to current target dirt pile + spawnpoints = list(factory.state.agents_conf.values())[agent_index]['positions'] + all_target_dirts = [] + if 'DirtPiles' in factory.conf['Entities']: + tuples = ast.literal_eval(factory.conf['Entities']['DirtPiles']['coords_or_quantity']) + for t in tuples: + all_target_dirts.append(t) + assigned_spawn_positions = [] + for j in range(len(spawnpoints) // len(all_target_dirts)): + assigned_spawn_positions.append(spawnpoints[j * len(all_target_dirts) + all_target_dirts.index(target_dirt_pos)]) + for spawn_pos in assigned_spawn_positions: + action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(spawn_pos))) + + render_arrows = [] for position, probabilities in probabilities_map.items(): if position not in wall_positions: if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall - sorted_indices = sorted(range(len(probabilities)), key=lambda i: -probabilities[i]) + sorted_indices = np.argsort(np.argsort(-probabilities)) colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow'] - + render_arrows.append([]) for rank, direction_index in enumerate(sorted_indices): action = directions[direction_index] - probability = probabilities[direction_index] - arrow_color = colors[rank] - if probability > 0: - action_entity = RenderEntity( - name=arrow_color, - pos=position, - probability=probability, - rotation=direction_index * 90 - ) - action_entities.append(action_entity) + probability = probabilities[rank] + arrow_color = colors[direction_index] + render_arrows[-1].append((probability, arrow_color, position)) + + # Swap west and east + for l in render_arrows: + l[1], l[3] = l[3], l[1] + for l in render_arrows: + for rank, (probability, arrow_color, position) in enumerate(l): + if probability > 0: + action_entity = RenderEntity( + name=arrow_color, + pos=position, + probability=probability, + rotation=rank * 90 + ) + action_entities.append(action_entity) renderer.render_multi_action_icons(action_entities, result_path)