This commit is contained in:
Chanumask
2024-05-06 19:30:27 +02:00
parent 865669055d
commit 39b123221b
3 changed files with 54 additions and 55 deletions
marl_factory_grid/utils
test_run.py

@ -66,31 +66,12 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
print('Plotting done.')
rotation_mapping = {
'north': ('cardinal', 0),
'east': ('cardinal', 270),
'south': ('cardinal', 180),
'west': ('cardinal', 90),
'north_east': ('diagonal', 0),
'south_east': ('diagonal', 270),
'south_west': ('diagonal', 180),
'north_west': ('diagonal', 90)
}
def swap_coordinates(positions):
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
# Single position, directly return swapped
return positions[1], positions[0]
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
# Single position in NumPy array
return positions[1], positions[0]
else:
# Assume it's an iterable of positions
return [(y, x) for x, y in positions]
def plot_routes(factory, agents):
"""
Creates a plot of the agents' actions on the level map by creating a Renderer and Render Entities that hold the
icon that corresponds to the action. For deterministic agents, simply displays the agents path of actions while for
RL agents that can supply an action map or action probabilities from their policy net.
"""
renderer = Renderer(factory.map.level_shape, custom_assets_path={
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
@ -104,17 +85,11 @@ def plot_routes(factory, agents):
wall_positions = factory.map.walls
swapped_wall_positions = swap_coordinates(wall_positions)
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in swapped_wall_positions]
action_entities = list(wall_entities)
action_entities = []
for index, agent in enumerate(agents):
# Add walls to the action_entities list
for pos in swapped_wall_positions:
wall_entity = RenderEntity(
name='wall',
probability=0,
pos=np.array(pos),
)
action_entities.append(wall_entity)
current_position = swap_coordinates(agent.spawn_position)
if hasattr(agent, 'action_probabilities'):
# Handle RL agents with action probabilities
@ -123,9 +98,6 @@ def plot_routes(factory, agents):
# Handle deterministic agents by iterating through all actions in the list
top_actions = [(action, 1.0) for action in agent.action_list]
current_position = agent.spawn_position
current_position = swap_coordinates(current_position)
for action, probability in top_actions:
if action.lower() in rotation_mapping:
base_icon, rotation = rotation_mapping[action.lower()]
@ -142,27 +114,54 @@ def plot_routes(factory, agents):
probability=probability,
rotation=rotation
)
# print(f"curr: {current_position}, new: {new_position}, pos: {action_entity.pos}")
action_entities.append(action_entity)
current_position = new_position
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
def swap_coordinates(positions):
"""
Swaps x and y coordinates of single positions, lists or arrays
"""
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
return positions[1], positions[0]
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
return positions[1], positions[0]
else:
return [(y, x) for x, y in positions]
def action_to_coords(current_position, action):
direction_mapping = {
'north': (0, -1),
'south': (0, 1),
'east': (1, 0),
'west': (-1, 0),
'north_east': (1, -1),
'north_west': (-1, -1),
'south_east': (1, 1),
'south_west': (-1, 1)
}
"""
Calculates new coordinates based on the current position and a movement action.
"""
delta = direction_mapping.get(action)
if delta is not None:
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
return new_position
print(f"No valid movement action found for {action}.")
return current_position
rotation_mapping = {
'north': ('cardinal', 0),
'east': ('cardinal', 270),
'south': ('cardinal', 180),
'west': ('cardinal', 90),
'north_east': ('diagonal', 0),
'south_east': ('diagonal', 270),
'south_west': ('diagonal', 180),
'north_west': ('diagonal', 90)
}
direction_mapping = {
'north': (0, -1),
'south': (0, 1),
'east': (1, 0),
'west': (-1, 0),
'north_east': (1, -1),
'north_west': (-1, -1),
'south_east': (1, 1),
'south_west': (-1, 1)
}

@ -237,13 +237,13 @@ class Renderer:
def render_action_icons(self, action_entities):
"""
Renders action icons based on the entities' specified actions, positions, and probabilities.
Renders action icons based on the entities' specified actions' name, position, rotation and probability.
Renders probabilities unequal 0.
:param action_entities: List of entities representing actions.
:type action_entities: List[RenderEntity]
"""
self.fill_bg() # Clear the background
self.fill_bg()
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
for action_entity in action_entities:
@ -253,11 +253,11 @@ class Renderer:
# Load and potentially rotate the icon based on action name
img = self.assets[action_entity.name.lower()]
if img is None:
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
continue
if hasattr(action_entity, 'rotation'):
img = pygame.transform.rotate(img, action_entity.rotation)
if img is None:
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
continue
# Blit the icon image
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,

@ -41,4 +41,4 @@ if __name__ == '__main__':
print(f'Episode {episode} done...')
break
plot_routes(factory, agents, )
plot_routes(factory, agents)