added plotting probability maps

This commit is contained in:
Chanumask
2024-05-08 14:27:08 +02:00
committed by Julian Schönberger
parent 3f88c4ee74
commit 83f0c70cfb
4 changed files with 124 additions and 11 deletions

View File

@@ -90,7 +90,7 @@ Entities:
General: General:
env_seed: 69 env_seed: 69
individual_rewards: true individual_rewards: true
level_name: large level_name: quadrant
pomdp_r: 3 pomdp_r: 3
verbose: false verbose: false
tests: false tests: false

View File

@@ -1,3 +1,4 @@
import json
import pickle import pickle
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
@@ -83,9 +84,8 @@ def plot_routes(factory, agents):
'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png', 'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'}) 'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})
wall_positions = factory.map.walls wall_positions = swap_coordinates(factory.map.walls)
swapped_wall_positions = swap_coordinates(wall_positions) wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in swapped_wall_positions]
action_entities = list(wall_entities) action_entities = list(wall_entities)
for index, agent in enumerate(agents): for index, agent in enumerate(agents):
@@ -117,7 +117,60 @@ def plot_routes(factory, agents):
action_entities.append(action_entity) action_entities.append(action_entity)
current_position = new_position current_position = new_position
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not renderer.render_single_action_icons(action_entities) # move in/out loop for graph per agent or not
def plot_action_maps(factory, agents):
    """
    Plot per-cell action probabilities as coloured, rotated arrows.

    Walls are added as full-cell entities. For every other cell of the loaded action map whose
    probabilities sum to more than 0, the actions are ranked by descending probability and mapped to the
    green, yellow, red and grey arrow assets (in that order). Actions with probability 0 are not drawn.
    The collected entities are handed to ``Renderer.render_multi_action_icons``.

    :param factory: Environment whose ``map.level_shape`` and ``map.walls`` describe the level.
    :param agents: Agents to plot for. NOTE(review): the loop body does not read the agent yet, so every
                   agent appends the same entities built from the example map file.
    :raises IndexError: If a cell lists more actions than there are arrow colours (four).
    """
    renderer = Renderer(factory.map.level_shape, custom_assets_path={
        'green_arrow': 'marl_factory_grid/utils/plotting/action_assets/green_arrow.png',
        'yellow_arrow': 'marl_factory_grid/utils/plotting/action_assets/yellow_arrow.png',
        'red_arrow': 'marl_factory_grid/utils/plotting/action_assets/red_arrow.png',
        'grey_arrow': 'marl_factory_grid/utils/plotting/action_assets/grey_arrow.png',
        'wall': 'marl_factory_grid/environment/assets/wall.png',
    })

    # Arrow assets ordered from the most to the least probable action of a cell.
    colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']

    wall_positions = swap_coordinates(factory.map.walls)
    wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
    action_entities = list(wall_entities)

    # TODO: read agent.action_probability_map (when the agent provides one) instead of the example file.
    dummy_action_map = load_action_map("example_action_map.txt")
    for agent in agents:
        for y, row in enumerate(dummy_action_map):
            for x, action_probabilities in enumerate(row):
                position = (x, y)
                # An all-zero cell would indicate a wall, so it is skipped like an explicit wall position.
                if position not in wall_positions and sum(action_probabilities) > 0:
                    # Action indices ordered by descending probability; ties keep their index order.
                    sorted_indices = sorted(range(len(action_probabilities)),
                                            key=lambda i: -action_probabilities[i])
                    for rank, direction_index in enumerate(sorted_indices):
                        probability = action_probabilities[direction_index]
                        arrow_color = colors[rank]
                        if probability > 0:
                            # Index order is north, east, south, west; each step rotates the arrow by 90.
                            action_entities.append(RenderEntity(
                                name=arrow_color,
                                pos=position,
                                probability=probability,
                                rotation=direction_index * 90
                            ))

    renderer.render_multi_action_icons(action_entities)
def load_action_map(file_path):
    """
    Load an action map from a JSON file.

    :param file_path: Path of the JSON file to read.
    :type file_path: str or os.PathLike
    :return: The decoded JSON content, returned unchanged.
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def swap_coordinates(positions): def swap_coordinates(positions):

View File

@@ -1,7 +1,7 @@
import sys import sys
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque, defaultdict
from itertools import product from itertools import product
import numpy as np import numpy as np
@@ -240,7 +240,7 @@ class Renderer:
return np.transpose(rgb_obs, (2, 0, 1)) return np.transpose(rgb_obs, (2, 0, 1))
# return torch.from_numpy(rgb_obs).permute(2, 0, 1) # return torch.from_numpy(rgb_obs).permute(2, 0, 1)
def render_action_icons(self, action_entities): def render_single_action_icons(self, action_entities):
""" """
Renders action icons based on the entities' specified actions' name, position, rotation and probability. Renders action icons based on the entities' specified actions' name, position, rotation and probability.
Renders probabilities unequal 0. Renders probabilities unequal 0.
@@ -249,7 +249,6 @@ class Renderer:
:type action_entities: List[RenderEntity] :type action_entities: List[RenderEntity]
""" """
self.fill_bg() self.fill_bg()
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
for action_entity in action_entities: for action_entity in action_entities:
if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1: if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
@@ -271,13 +270,74 @@ class Renderer:
# Render the probability next to the icon if it exists # Render the probability next to the icon if it exists
if hasattr(action_entity, 'probability') and action_entity.probability != 0: if hasattr(action_entity, 'probability') and action_entity.probability != 0:
prob_text = font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0)) prob_text = self.font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left) prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
self.screen.blit(prob_text, prob_text_rect) self.screen.blit(prob_text, prob_text_rect)
pygame.display.flip() # Update the display with all new blits pygame.display.flip() # Update the display with all new blits
self.save_screen("route_graph") self.save_screen("route_graph")
def render_multi_action_icons(self, action_entities):
    """
    Renders multiple action icons at the same position without overlap and arranges them based on direction,
    except for walls which cover the entire grid cell.

    Entities are grouped by position. Within a cell, arrows are scaled to half the cell size and shifted into
    a quadrant selected by their rotation (0, 90, 180 or 270); any other rotation is drawn at the cell centre.
    A probability greater than 0 is printed on top of its arrow with four decimals. The finished screen is
    saved under the name "multi_action_graph".

    :param action_entities: Entities to draw; each needs ``name``, ``pos``, ``rotation`` and ``probability``.
    :type action_entities: List[RenderEntity]
    """
    self.fill_bg()
    font = pygame.font.Font(None, 18)

    # Arrows share a cell, so each one gets half the cell's edge length.
    entity_size = self.cell_size // 2
    # Pixel offsets from the cell centre, keyed by rotation, placing each direction in its own quadrant.
    offsets = {
        0: (-entity_size // 2, -entity_size // 2),  # North
        90: (-entity_size // 2, entity_size // 2),  # East
        180: (entity_size // 2, entity_size // 2),  # South
        270: (entity_size // 2, -entity_size // 2)  # West
    }

    # Group the entities by grid position so that every cell is laid out as a unit.
    position_dict = defaultdict(list)
    for entity in action_entities:
        position_dict[tuple(entity.pos)].append(entity)

    for position, entities in position_dict.items():
        center_x = position[0] * self.cell_size + self.cell_size // 2
        center_y = position[1] * self.cell_size + self.cell_size // 2

        # Sort entities based on direction to ensure consistent positioning
        entities.sort(key=lambda x: x.rotation)

        for entity in entities:
            # A missing key must take the skip path below instead of aborting the whole plot.
            try:
                img = self.assets[entity.name.lower()]
            except KeyError:
                img = None
            if img is None:
                print(f"Error: No asset available for '{entity.name}'. Skipping rendering this entity.")
                continue

            img = pygame.transform.rotate(img, entity.rotation)

            if entity.name == 'wall':
                # Walls fill the whole cell and stay centred.
                img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
                img_rect = img.get_rect(center=(center_x, center_y))
            else:
                # Arrows are scaled down and moved into the quadrant of their direction.
                img = pygame.transform.scale(img, (entity_size, entity_size))
                offset = offsets.get(entity.rotation, (0, 0))
                img_rect = img.get_rect(center=(center_x + offset[0], center_y + offset[1]))

            self.screen.blit(img, img_rect)

            # Render the probability on top of the icon if it is non-zero
            if entity.probability > 0 and entity.name != 'wall':
                formatted_probability = f"{entity.probability:.4f}"
                prob_text = font.render(formatted_probability, True, (0, 0, 0))  # Black color for readability
                prob_text_rect = prob_text.get_rect(center=img_rect.center)  # Center text on the arrow
                self.screen.blit(prob_text, prob_text_rect)

    pygame.display.flip()  # Update the display
    self.save_screen("multi_action_graph")
def save_screen(self, filename): def save_screen(self, filename):
""" """
Saves the current screen to a PNG file, appending a counter to ensure uniqueness. Saves the current screen to a PNG file, appending a counter to ensure uniqueness.

View File

@@ -8,7 +8,7 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
from marl_factory_grid.environment.factory import Factory from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes, plot_action_maps
if __name__ == '__main__': if __name__ == '__main__':
@@ -40,4 +40,4 @@ if __name__ == '__main__':
print(f'Episode {episode} done...') print(f'Episode {episode} done...')
break break
plot_routes(factory, agents) plot_action_maps(factory, agents)