This commit is contained in:
Chanumask
2024-05-06 19:30:27 +02:00
committed by Julian Schönberger
parent 0295af34b1
commit 3f88c4ee74
3 changed files with 54 additions and 55 deletions

View File

@@ -66,31 +66,12 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
print('Plotting done.') print('Plotting done.')
rotation_mapping = {
'north': ('cardinal', 0),
'east': ('cardinal', 270),
'south': ('cardinal', 180),
'west': ('cardinal', 90),
'north_east': ('diagonal', 0),
'south_east': ('diagonal', 270),
'south_west': ('diagonal', 180),
'north_west': ('diagonal', 90)
}
def swap_coordinates(positions):
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
# Single position, directly return swapped
return positions[1], positions[0]
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
# Single position in NumPy array
return positions[1], positions[0]
else:
# Assume it's an iterable of positions
return [(y, x) for x, y in positions]
def plot_routes(factory, agents): def plot_routes(factory, agents):
"""
Creates a plot of the agents' actions on the level map by creating a Renderer and Render Entities that hold the
icon that corresponds to the action. For deterministic agents, simply displays the agents path of actions while for
RL agents that can supply an action map or action probabilities from their policy net.
"""
renderer = Renderer(factory.map.level_shape, custom_assets_path={ renderer = Renderer(factory.map.level_shape, custom_assets_path={
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png', 'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png', 'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
@@ -104,17 +85,11 @@ def plot_routes(factory, agents):
wall_positions = factory.map.walls wall_positions = factory.map.walls
swapped_wall_positions = swap_coordinates(wall_positions) swapped_wall_positions = swap_coordinates(wall_positions)
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in swapped_wall_positions]
action_entities = list(wall_entities)
action_entities = []
for index, agent in enumerate(agents): for index, agent in enumerate(agents):
# Add walls to the action_entities list current_position = swap_coordinates(agent.spawn_position)
for pos in swapped_wall_positions:
wall_entity = RenderEntity(
name='wall',
probability=0,
pos=np.array(pos),
)
action_entities.append(wall_entity)
if hasattr(agent, 'action_probabilities'): if hasattr(agent, 'action_probabilities'):
# Handle RL agents with action probabilities # Handle RL agents with action probabilities
@@ -123,9 +98,6 @@ def plot_routes(factory, agents):
# Handle deterministic agents by iterating through all actions in the list # Handle deterministic agents by iterating through all actions in the list
top_actions = [(action, 1.0) for action in agent.action_list] top_actions = [(action, 1.0) for action in agent.action_list]
current_position = agent.spawn_position
current_position = swap_coordinates(current_position)
for action, probability in top_actions: for action, probability in top_actions:
if action.lower() in rotation_mapping: if action.lower() in rotation_mapping:
base_icon, rotation = rotation_mapping[action.lower()] base_icon, rotation = rotation_mapping[action.lower()]
@@ -142,27 +114,54 @@ def plot_routes(factory, agents):
probability=probability, probability=probability,
rotation=rotation rotation=rotation
) )
# print(f"curr: {current_position}, new: {new_position}, pos: {action_entity.pos}")
action_entities.append(action_entity) action_entities.append(action_entity)
current_position = new_position current_position = new_position
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
def swap_coordinates(positions):
"""
Swaps x and y coordinates of single positions, lists or arrays
"""
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
return positions[1], positions[0]
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
return positions[1], positions[0]
else:
return [(y, x) for x, y in positions]
def action_to_coords(current_position, action): def action_to_coords(current_position, action):
direction_mapping = { """
'north': (0, -1), Calculates new coordinates based on the current position and a movement action.
'south': (0, 1), """
'east': (1, 0),
'west': (-1, 0),
'north_east': (1, -1),
'north_west': (-1, -1),
'south_east': (1, 1),
'south_west': (-1, 1)
}
delta = direction_mapping.get(action) delta = direction_mapping.get(action)
if delta is not None: if delta is not None:
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]] new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
return new_position return new_position
print(f"No valid movement action found for {action}.") print(f"No valid movement action found for {action}.")
return current_position return current_position
rotation_mapping = {
'north': ('cardinal', 0),
'east': ('cardinal', 270),
'south': ('cardinal', 180),
'west': ('cardinal', 90),
'north_east': ('diagonal', 0),
'south_east': ('diagonal', 270),
'south_west': ('diagonal', 180),
'north_west': ('diagonal', 90)
}
direction_mapping = {
'north': (0, -1),
'south': (0, 1),
'east': (1, 0),
'west': (-1, 0),
'north_east': (1, -1),
'north_west': (-1, -1),
'south_east': (1, 1),
'south_west': (-1, 1)
}

View File

@@ -242,13 +242,13 @@ class Renderer:
def render_action_icons(self, action_entities): def render_action_icons(self, action_entities):
""" """
Renders action icons based on the entities' specified actions, positions, and probabilities. Renders action icons based on the entities' specified actions' name, position, rotation and probability.
Renders probabilities unequal 0.
:param action_entities: List of entities representing actions. :param action_entities: List of entities representing actions.
:type action_entities: List[RenderEntity] :type action_entities: List[RenderEntity]
""" """
self.fill_bg()
self.fill_bg() # Clear the background
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
for action_entity in action_entities: for action_entity in action_entities:
@@ -258,11 +258,11 @@ class Renderer:
# Load and potentially rotate the icon based on action name # Load and potentially rotate the icon based on action name
img = self.assets[action_entity.name.lower()] img = self.assets[action_entity.name.lower()]
if img is None:
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
continue
if hasattr(action_entity, 'rotation'): if hasattr(action_entity, 'rotation'):
img = pygame.transform.rotate(img, action_entity.rotation) img = pygame.transform.rotate(img, action_entity.rotation)
if img is None:
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
continue
# Blit the icon image # Blit the icon image
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2, img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,

View File

@@ -40,4 +40,4 @@ if __name__ == '__main__':
print(f'Episode {episode} done...') print(f'Episode {episode} done...')
break break
plot_routes(factory, agents, ) plot_routes(factory, agents)