mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-01-15 23:41:39 +01:00
cleanup
This commit is contained in:
committed by
Julian Schönberger
parent
0295af34b1
commit
3f88c4ee74
@@ -66,31 +66,12 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
|
|||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
rotation_mapping = {
|
|
||||||
'north': ('cardinal', 0),
|
|
||||||
'east': ('cardinal', 270),
|
|
||||||
'south': ('cardinal', 180),
|
|
||||||
'west': ('cardinal', 90),
|
|
||||||
'north_east': ('diagonal', 0),
|
|
||||||
'south_east': ('diagonal', 270),
|
|
||||||
'south_west': ('diagonal', 180),
|
|
||||||
'north_west': ('diagonal', 90)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def swap_coordinates(positions):
|
|
||||||
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
|
|
||||||
# Single position, directly return swapped
|
|
||||||
return positions[1], positions[0]
|
|
||||||
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
|
|
||||||
# Single position in NumPy array
|
|
||||||
return positions[1], positions[0]
|
|
||||||
else:
|
|
||||||
# Assume it's an iterable of positions
|
|
||||||
return [(y, x) for x, y in positions]
|
|
||||||
|
|
||||||
|
|
||||||
def plot_routes(factory, agents):
|
def plot_routes(factory, agents):
|
||||||
|
"""
|
||||||
|
Creates a plot of the agents' actions on the level map by creating a Renderer and Render Entities that hold the
|
||||||
|
icon that corresponds to the action. For deterministic agents, simply displays the agents path of actions while for
|
||||||
|
RL agents that can supply an action map or action probabilities from their policy net.
|
||||||
|
"""
|
||||||
renderer = Renderer(factory.map.level_shape, custom_assets_path={
|
renderer = Renderer(factory.map.level_shape, custom_assets_path={
|
||||||
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
|
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
|
||||||
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
|
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
|
||||||
@@ -104,17 +85,11 @@ def plot_routes(factory, agents):
|
|||||||
|
|
||||||
wall_positions = factory.map.walls
|
wall_positions = factory.map.walls
|
||||||
swapped_wall_positions = swap_coordinates(wall_positions)
|
swapped_wall_positions = swap_coordinates(wall_positions)
|
||||||
|
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in swapped_wall_positions]
|
||||||
|
action_entities = list(wall_entities)
|
||||||
|
|
||||||
action_entities = []
|
|
||||||
for index, agent in enumerate(agents):
|
for index, agent in enumerate(agents):
|
||||||
# Add walls to the action_entities list
|
current_position = swap_coordinates(agent.spawn_position)
|
||||||
for pos in swapped_wall_positions:
|
|
||||||
wall_entity = RenderEntity(
|
|
||||||
name='wall',
|
|
||||||
probability=0,
|
|
||||||
pos=np.array(pos),
|
|
||||||
)
|
|
||||||
action_entities.append(wall_entity)
|
|
||||||
|
|
||||||
if hasattr(agent, 'action_probabilities'):
|
if hasattr(agent, 'action_probabilities'):
|
||||||
# Handle RL agents with action probabilities
|
# Handle RL agents with action probabilities
|
||||||
@@ -123,9 +98,6 @@ def plot_routes(factory, agents):
|
|||||||
# Handle deterministic agents by iterating through all actions in the list
|
# Handle deterministic agents by iterating through all actions in the list
|
||||||
top_actions = [(action, 1.0) for action in agent.action_list]
|
top_actions = [(action, 1.0) for action in agent.action_list]
|
||||||
|
|
||||||
current_position = agent.spawn_position
|
|
||||||
current_position = swap_coordinates(current_position)
|
|
||||||
|
|
||||||
for action, probability in top_actions:
|
for action, probability in top_actions:
|
||||||
if action.lower() in rotation_mapping:
|
if action.lower() in rotation_mapping:
|
||||||
base_icon, rotation = rotation_mapping[action.lower()]
|
base_icon, rotation = rotation_mapping[action.lower()]
|
||||||
@@ -142,27 +114,54 @@ def plot_routes(factory, agents):
|
|||||||
probability=probability,
|
probability=probability,
|
||||||
rotation=rotation
|
rotation=rotation
|
||||||
)
|
)
|
||||||
# print(f"curr: {current_position}, new: {new_position}, pos: {action_entity.pos}")
|
|
||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
current_position = new_position
|
current_position = new_position
|
||||||
|
|
||||||
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
|
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
|
||||||
|
|
||||||
|
|
||||||
|
def swap_coordinates(positions):
|
||||||
|
"""
|
||||||
|
Swaps x and y coordinates of single positions, lists or arrays
|
||||||
|
"""
|
||||||
|
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
|
||||||
|
return positions[1], positions[0]
|
||||||
|
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
|
||||||
|
return positions[1], positions[0]
|
||||||
|
else:
|
||||||
|
return [(y, x) for x, y in positions]
|
||||||
|
|
||||||
|
|
||||||
def action_to_coords(current_position, action):
|
def action_to_coords(current_position, action):
|
||||||
direction_mapping = {
|
"""
|
||||||
'north': (0, -1),
|
Calculates new coordinates based on the current position and a movement action.
|
||||||
'south': (0, 1),
|
"""
|
||||||
'east': (1, 0),
|
|
||||||
'west': (-1, 0),
|
|
||||||
'north_east': (1, -1),
|
|
||||||
'north_west': (-1, -1),
|
|
||||||
'south_east': (1, 1),
|
|
||||||
'south_west': (-1, 1)
|
|
||||||
}
|
|
||||||
delta = direction_mapping.get(action)
|
delta = direction_mapping.get(action)
|
||||||
if delta is not None:
|
if delta is not None:
|
||||||
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
|
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
|
||||||
return new_position
|
return new_position
|
||||||
print(f"No valid movement action found for {action}.")
|
print(f"No valid movement action found for {action}.")
|
||||||
return current_position
|
return current_position
|
||||||
|
|
||||||
|
|
||||||
|
rotation_mapping = {
|
||||||
|
'north': ('cardinal', 0),
|
||||||
|
'east': ('cardinal', 270),
|
||||||
|
'south': ('cardinal', 180),
|
||||||
|
'west': ('cardinal', 90),
|
||||||
|
'north_east': ('diagonal', 0),
|
||||||
|
'south_east': ('diagonal', 270),
|
||||||
|
'south_west': ('diagonal', 180),
|
||||||
|
'north_west': ('diagonal', 90)
|
||||||
|
}
|
||||||
|
|
||||||
|
direction_mapping = {
|
||||||
|
'north': (0, -1),
|
||||||
|
'south': (0, 1),
|
||||||
|
'east': (1, 0),
|
||||||
|
'west': (-1, 0),
|
||||||
|
'north_east': (1, -1),
|
||||||
|
'north_west': (-1, -1),
|
||||||
|
'south_east': (1, 1),
|
||||||
|
'south_west': (-1, 1)
|
||||||
|
}
|
||||||
|
|||||||
@@ -242,13 +242,13 @@ class Renderer:
|
|||||||
|
|
||||||
def render_action_icons(self, action_entities):
|
def render_action_icons(self, action_entities):
|
||||||
"""
|
"""
|
||||||
Renders action icons based on the entities' specified actions, positions, and probabilities.
|
Renders action icons based on the entities' specified actions' name, position, rotation and probability.
|
||||||
|
Renders probabilities unequal 0.
|
||||||
|
|
||||||
:param action_entities: List of entities representing actions.
|
:param action_entities: List of entities representing actions.
|
||||||
:type action_entities: List[RenderEntity]
|
:type action_entities: List[RenderEntity]
|
||||||
"""
|
"""
|
||||||
|
self.fill_bg()
|
||||||
self.fill_bg() # Clear the background
|
|
||||||
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
|
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
|
||||||
|
|
||||||
for action_entity in action_entities:
|
for action_entity in action_entities:
|
||||||
@@ -258,11 +258,11 @@ class Renderer:
|
|||||||
|
|
||||||
# Load and potentially rotate the icon based on action name
|
# Load and potentially rotate the icon based on action name
|
||||||
img = self.assets[action_entity.name.lower()]
|
img = self.assets[action_entity.name.lower()]
|
||||||
|
if img is None:
|
||||||
|
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
|
||||||
|
continue
|
||||||
if hasattr(action_entity, 'rotation'):
|
if hasattr(action_entity, 'rotation'):
|
||||||
img = pygame.transform.rotate(img, action_entity.rotation)
|
img = pygame.transform.rotate(img, action_entity.rotation)
|
||||||
if img is None:
|
|
||||||
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Blit the icon image
|
# Blit the icon image
|
||||||
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,
|
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,
|
||||||
|
|||||||
@@ -40,4 +40,4 @@ if __name__ == '__main__':
|
|||||||
print(f'Episode {episode} done...')
|
print(f'Episode {episode} done...')
|
||||||
break
|
break
|
||||||
|
|
||||||
plot_routes(factory, agents, )
|
plot_routes(factory, agents)
|
||||||
|
|||||||
Reference in New Issue
Block a user