mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-01-15 23:41:39 +01:00
added plotting probability maps
This commit is contained in:
committed by
Julian Schönberger
parent
3f88c4ee74
commit
83f0c70cfb
@@ -90,7 +90,7 @@ Entities:
|
|||||||
General:
|
General:
|
||||||
env_seed: 69
|
env_seed: 69
|
||||||
individual_rewards: true
|
individual_rewards: true
|
||||||
level_name: large
|
level_name: quadrant
|
||||||
pomdp_r: 3
|
pomdp_r: 3
|
||||||
verbose: false
|
verbose: false
|
||||||
tests: false
|
tests: false
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -83,9 +84,8 @@ def plot_routes(factory, agents):
|
|||||||
'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
|
'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
|
||||||
'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})
|
'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})
|
||||||
|
|
||||||
wall_positions = factory.map.walls
|
wall_positions = swap_coordinates(factory.map.walls)
|
||||||
swapped_wall_positions = swap_coordinates(wall_positions)
|
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
|
||||||
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in swapped_wall_positions]
|
|
||||||
action_entities = list(wall_entities)
|
action_entities = list(wall_entities)
|
||||||
|
|
||||||
for index, agent in enumerate(agents):
|
for index, agent in enumerate(agents):
|
||||||
@@ -117,7 +117,60 @@ def plot_routes(factory, agents):
|
|||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
current_position = new_position
|
current_position = new_position
|
||||||
|
|
||||||
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
|
renderer.render_single_action_icons(action_entities) # move in/out loop for graph per agent or not
|
||||||
|
|
||||||
|
|
||||||
|
def plot_action_maps(factory, agents):
    """
    Plots per-cell action probabilities as colored arrows on top of the level's walls.

    For every non-wall cell of the loaded action map, one arrow is created per action with a
    probability greater than zero. The arrow asset encodes the rank of that action within the
    cell (green = most probable, then yellow, red, grey) and its rotation encodes the action
    index (index * 90 degrees). All entities are handed to ``Renderer.render_multi_action_icons``.

    :param factory: Environment providing ``map.level_shape`` and ``map.walls``.
    :param agents: Iterable of agents. The placeholder map is added once per agent.
    """
    renderer = Renderer(factory.map.level_shape, custom_assets_path={
        'green_arrow': 'marl_factory_grid/utils/plotting/action_assets/green_arrow.png',
        'yellow_arrow': 'marl_factory_grid/utils/plotting/action_assets/yellow_arrow.png',
        'red_arrow': 'marl_factory_grid/utils/plotting/action_assets/red_arrow.png',
        'grey_arrow': 'marl_factory_grid/utils/plotting/action_assets/grey_arrow.png',
        'wall': 'marl_factory_grid/environment/assets/wall.png',
    })

    # Arrow asset per probability rank; index 0 is the most probable action of a cell.
    # Action indices are interpreted as north, east, south, west (in that order).
    colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']

    wall_positions = swap_coordinates(factory.map.walls)
    action_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
    # Set lookup instead of scanning the wall list once per grid cell.
    blocked = {tuple(pos) for pos in wall_positions}

    # TODO: replace the placeholder file with `agent.action_probability_map` once agents expose it.
    dummy_action_map = load_action_map("example_action_map.txt")
    # NOTE(review): the agent itself is not used yet, so the same arrows are added once per agent.
    for _agent in agents:
        for y, row in enumerate(dummy_action_map):
            for x, action_probabilities in enumerate(row):
                position = (x, y)
                if position in blocked:
                    continue
                # An all-zero vector marks a cell without actions (e.g. a wall in the map file).
                if not sum(action_probabilities) > 0:
                    continue

                # Indices of the actions, most probable first (ties keep their index order).
                sorted_indices = sorted(range(len(action_probabilities)),
                                        key=lambda i: -action_probabilities[i])
                for rank, direction_index in enumerate(sorted_indices):
                    probability = action_probabilities[direction_index]
                    arrow_color = colors[rank]
                    if probability > 0:
                        action_entities.append(RenderEntity(
                            name=arrow_color,
                            pos=position,
                            probability=probability,
                            rotation=direction_index * 90
                        ))

    # NOTE(review): assumed to run once after all agents were processed - confirm against the original layout.
    renderer.render_multi_action_icons(action_entities)
|
||||||
|
|
||||||
|
|
||||||
|
def load_action_map(file_path):
    """
    Loads an action probability map from a JSON-formatted text file.

    :param file_path: Path of the file to read (string or path-like).
    :return: The parsed JSON content. ``plot_action_maps`` indexes it as ``action_map[y][x]`` and
             expects a sequence of action probabilities at each cell.
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Explicit encoding keeps parsing independent of the platform's default locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
|
||||||
|
|
||||||
|
|
||||||
def swap_coordinates(positions):
|
def swap_coordinates(positions):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import deque
|
from collections import deque, defaultdict
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -240,7 +240,7 @@ class Renderer:
|
|||||||
return np.transpose(rgb_obs, (2, 0, 1))
|
return np.transpose(rgb_obs, (2, 0, 1))
|
||||||
# return torch.from_numpy(rgb_obs).permute(2, 0, 1)
|
# return torch.from_numpy(rgb_obs).permute(2, 0, 1)
|
||||||
|
|
||||||
def render_action_icons(self, action_entities):
|
def render_single_action_icons(self, action_entities):
|
||||||
"""
|
"""
|
||||||
Renders action icons based on the entities' specified actions' name, position, rotation and probability.
|
Renders action icons based on the entities' specified actions' name, position, rotation and probability.
|
||||||
Renders probabilities unequal 0.
|
Renders probabilities unequal 0.
|
||||||
@@ -249,7 +249,6 @@ class Renderer:
|
|||||||
:type action_entities: List[RenderEntity]
|
:type action_entities: List[RenderEntity]
|
||||||
"""
|
"""
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
|
|
||||||
|
|
||||||
for action_entity in action_entities:
|
for action_entity in action_entities:
|
||||||
if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
|
if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
|
||||||
@@ -271,13 +270,74 @@ class Renderer:
|
|||||||
|
|
||||||
# Render the probability next to the icon if it exists
|
# Render the probability next to the icon if it exists
|
||||||
if hasattr(action_entity, 'probability') and action_entity.probability != 0:
|
if hasattr(action_entity, 'probability') and action_entity.probability != 0:
|
||||||
prob_text = font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
|
prob_text = self.font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
|
||||||
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
|
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
|
||||||
self.screen.blit(prob_text, prob_text_rect)
|
self.screen.blit(prob_text, prob_text_rect)
|
||||||
|
|
||||||
pygame.display.flip() # Update the display with all new blits
|
pygame.display.flip() # Update the display with all new blits
|
||||||
self.save_screen("route_graph")
|
self.save_screen("route_graph")
|
||||||
|
|
||||||
|
def render_multi_action_icons(self, action_entities):
    """
    Renders several action icons per grid cell without overlap and saves the result.

    Entities are grouped by position. Walls are scaled to cover the whole cell; every other
    entity is scaled to half the cell size and shifted into one quadrant of the cell chosen by
    its rotation (0, 90, 180 or 270; any other rotation is drawn at the cell center). Positive
    probabilities of non-wall entities are written on top of their icon. The finished screen is
    passed to ``save_screen`` under the name "multi_action_graph".

    :param action_entities: Entities to draw; each needs ``name``, ``pos``, ``rotation`` and ``probability``.
    :type action_entities: List[RenderEntity]
    """
    self.fill_bg()
    font = pygame.font.Font(None, 18)

    # Icon size and quadrant offsets do not depend on the cell, so compute them once.
    entity_size = self.cell_size // 2  # Adjust size to fit multiple entities for non-wall entities
    offsets = {
        0: (-entity_size // 2, -entity_size // 2),  # North
        90: (-entity_size // 2, entity_size // 2),  # East
        180: (entity_size // 2, entity_size // 2),  # South
        270: (entity_size // 2, -entity_size // 2)  # West
    }

    # Group the entities by grid position so each cell is laid out as a unit.
    position_dict = defaultdict(list)
    for entity in action_entities:
        position_dict[tuple(entity.pos)].append(entity)

    for position, entities in position_dict.items():
        cell_center = (position[0] * self.cell_size + self.cell_size // 2,
                       position[1] * self.cell_size + self.cell_size // 2)

        # Draw in rotation order to keep the layout consistent between cells.
        for entity in sorted(entities, key=lambda e: e.rotation):
            asset_key = entity.name.lower()
            # A mapping that raises KeyError for unknown names would otherwise bypass the
            # "skip missing asset" handling below.
            try:
                img = self.assets[asset_key]
            except KeyError:
                img = None
            if img is None:
                print(f"Error: No asset available for '{entity.name}'. Skipping rendering this entity.")
                continue

            img = pygame.transform.rotate(img, entity.rotation)

            # Same (lower-cased) key as the asset lookup, so walls are detected consistently.
            is_wall = asset_key == 'wall'
            if is_wall:
                img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
                img_rect = img.get_rect(center=cell_center)
            else:
                img = pygame.transform.scale(img, (entity_size, entity_size))  # Scale down the image for arrows
                offset = offsets.get(entity.rotation, (0, 0))
                img_rect = img.get_rect(center=(cell_center[0] + offset[0], cell_center[1] + offset[1]))

            self.screen.blit(img, img_rect)

            # Render the probability on top of the icon if it is non-zero
            if entity.probability > 0 and not is_wall:
                formatted_probability = f"{entity.probability:.4f}"
                prob_text = font.render(formatted_probability, True, (0, 0, 0))  # Black color for readability
                prob_text_rect = prob_text.get_rect(center=img_rect.center)  # Center text on the arrow
                self.screen.blit(prob_text, prob_text_rect)

    pygame.display.flip()  # Update the display
    self.save_screen("multi_action_graph")
|
||||||
|
|
||||||
def save_screen(self, filename):
|
def save_screen(self, filename):
|
||||||
"""
|
"""
|
||||||
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
|
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
|
|||||||
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
|
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
|
||||||
from marl_factory_grid.environment.factory import Factory
|
from marl_factory_grid.environment.factory import Factory
|
||||||
|
|
||||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes
|
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes, plot_action_maps
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
@@ -40,4 +40,4 @@ if __name__ == '__main__':
|
|||||||
print(f'Episode {episode} done...')
|
print(f'Episode {episode} done...')
|
||||||
break
|
break
|
||||||
|
|
||||||
plot_routes(factory, agents)
|
plot_action_maps(factory, agents)
|
||||||
|
|||||||
Reference in New Issue
Block a user