Started visualization of routes in plot_single_runs; the action assets are still missing.

This commit is contained in:
Chanumask
2024-05-02 17:07:33 +02:00
parent 5ee39eba8d
commit 9f2cb103f4
12 changed files with 249 additions and 62 deletions

View File

@@ -33,10 +33,12 @@ class TSPBaseAgent(ABC):
self.local_optimization = True self.local_optimization = True
self._env = state self._env = state
self.state = self._env.state[c.AGENT][agent_i] self.state = self._env.state[c.AGENT][agent_i]
self.spawn_position = np.array(self.state.pos)
self._position_graph = self.generate_pos_graph() self._position_graph = self.generate_pos_graph()
self._static_route = None self._static_route = None
self.cached_route = None self.cached_route = None
self.fallback_action = None self.fallback_action = None
self.action_list = []
@abstractmethod @abstractmethod
def predict(self, *_, **__) -> int: def predict(self, *_, **__) -> int:

View File

@@ -30,6 +30,7 @@ class TSPDirtAgent(TSPBaseAgent):
action = self._use_door_or_move(door, di.DIRT) action = self._use_door_or_move(door, di.DIRT)
else: else:
action = self._predict_move(di.DIRT) action = self._predict_move(di.DIRT)
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
try: try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action) action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@@ -38,6 +38,7 @@ class TSPItemAgent(TSPBaseAgent):
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM) action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
else: else:
action = self._choose() action = self._choose()
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
try: try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action) action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@@ -38,6 +38,7 @@ class TSPTargetAgent(TSPBaseAgent):
action = self._use_door_or_move(door, d.DESTINATION) action = self._use_door_or_move(door, d.DESTINATION)
else: else:
action = self._predict_move(d.DESTINATION) action = self._predict_move(d.DESTINATION)
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model # Translate the action_object to an integer to have the same output as any other model
try: try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action) action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@@ -6,7 +6,7 @@ General:
level_name: simple_crossing level_name: simple_crossing
# View Radius; 0 = full observability # View Radius; 0 = full observability
pomdp_r: 0 pomdp_r: 0
verbose: true verbose: false
tests: false tests: false
Agents: Agents:

View File

@@ -1,45 +1,45 @@
Agents: Agents:
Clean test agent: # Clean test agent:
Actions: # Actions:
- Noop # - Noop
- Charge # - Charge
- Clean # - Clean
- DoorUse # - DoorUse
- Move8 # - Move8
Observations: # Observations:
- Combined: # - Combined:
- Other # - Other
- Walls # - Walls
- GlobalPosition # - GlobalPosition
- Battery # - Battery
- ChargePods # - ChargePods
- DirtPiles # - DirtPiles
- Destinations # - Destinations
- Doors # - Doors
- Maintainers # - Maintainers
Clones: 0 # Clones: 0
Item test agent: # Item test agent:
Actions: # Actions:
- Noop # - Noop
- Charge # - Charge
- DestAction # - DestAction
- DoorUse # - DoorUse
- ItemAction # - ItemAction
- Move8 # - Move8
Observations: # Observations:
- Combined: # - Combined:
- Other # - Other
- Walls # - Walls
- GlobalPosition # - GlobalPosition
- Battery # - Battery
- ChargePods # - ChargePods
- Destinations # - Destinations
- Doors # - Doors
- Items # - Items
- Inventory # - Inventory
- DropOffLocations # - DropOffLocations
- Maintainers # - Maintainers
Clones: 0 # Clones: 0
Target test agent: Target test agent:
Actions: Actions:
- Noop - Noop
@@ -55,7 +55,7 @@ Agents:
- Destinations - Destinations
- Doors - Doors
- Maintainers - Maintainers
Clones: 0 Clones: 1
Entities: Entities:
@@ -118,7 +118,7 @@ Rules:
max_steps: 500 max_steps: 500
Tests: Tests:
MaintainerTest: {} # MaintainerTest: {}
DirtAgentTest: {} # DirtAgentTest: {}
ItemAgentTest: {} # ItemAgentTest: {}
TargetAgentTest: {} # TargetAgentTest: {}

View File

@@ -42,6 +42,7 @@ class LevelParser(object):
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL) level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
self.level_shape = level_array.shape self.level_shape = level_array.shape
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape) self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
self.walls = None
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray: def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
""" """
@@ -74,6 +75,7 @@ class LevelParser(object):
# Walls # Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size) walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALLS: walls}) entities.add_items({c.WALLS: walls})
self.walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)
# Agents # Agents
entities.add_items({c.AGENT: Agents(self.size)}) entities.add_items({c.AGENT: Agents(self.size)})

View File

@@ -48,7 +48,7 @@ class EnvRecorder(Wrapper):
""" """
obs_type, obs, reward, done, info = self.env.step(actions) obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes: if not self.episodes or self._curr_episode in self.episodes:
summary: dict = self.env.summarize_state() summary: dict = self.env.unwrapped.summarize_state()
# summary.update(done=done) # summary.update(done=done)
# summary.update({'episode': self._curr_episode}) # summary.update({'episode': self._curr_episode})
# TODO Protobuff Adjustments ###### # TODO Protobuff Adjustments ######

View File

@@ -3,11 +3,15 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import numpy as np
import pandas as pd import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
from marl_factory_grid.utils.renderer import Renderer
from marl_factory_grid.utils.utility_classes import RenderEntity
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None, def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str = 'monitor', file_ext: str = 'pkl'): file_key: str = 'monitor', file_ext: str = 'pkl'):
@@ -60,3 +64,81 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex) prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.') print('Plotting done.')
# Maps a movement action name to (asset key of the arrow icon, rotation in degrees).
# The rotation is forwarded to the renderer, which hands it to pygame.transform.rotate.
rotation_mapping = {
    'north': ('cardinal', 0),
    'east': ('cardinal', 270),
    'south': ('cardinal', 180),
    'west': ('cardinal', 90),
    'northeast': ('diagonal', 0),
    'southeast': ('diagonal', 270),
    'southwest': ('diagonal', 180),
    'northwest': ('diagonal', 90)
}


def plot_routes(factory, agents):
    """
    Render the route of every agent as a sequence of arrow icons on top of the level walls.

    One image is rendered per agent: the walls are drawn first, then one icon per movement
    action, starting at the agent's spawn position and following the recorded actions.

    :param factory: Environment whose ``map`` provides ``level_shape`` and ``walls``.
    :param agents: Agents exposing ``spawn_position`` and either ``action_probabilities``
                   (mapping of action name to probability) or ``action_list``.
    """
    renderer = Renderer(factory.map.level_shape, custom_assets_path={
        'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
        'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
        'door': 'marl_factory_grid/utils/plotting/action_assets/door.png',
        'wall': 'marl_factory_grid/environment/assets/wall.png'})

    wall_positions = factory.map.walls

    for agent in agents:
        # Every agent gets its own image, so the walls are added anew for each one.
        action_entities = [
            RenderEntity(name='wall', probability=1.0, pos=np.array(pos))
            for pos in wall_positions
        ]

        current_position = agent.spawn_position

        if hasattr(agent, 'action_probabilities'):
            # Agents with action probabilities: keep the four most likely actions.
            top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
        else:
            # Deterministic agents: every recorded action counts with probability 1.
            top_actions = [(action, 1.0) for action in agent.action_list]

        for action, probability in top_actions:
            action_name = action.lower()
            icon = rotation_mapping.get(action_name)
            if icon is None:
                # Non-movement actions have no arrow icon and do not change the position.
                # The former fallback name 'north' is not a loaded asset key and made the
                # renderer fail, so such actions are skipped instead.
                continue
            icon_name, rotation = icon
            action_entities.append(RenderEntity(
                name=icon_name,
                pos=np.array(current_position),
                probability=probability,
                rotation=rotation
            ))
            current_position = action_to_coords(current_position, action_name)

        renderer.render_action_icons(action_entities)
def action_to_coords(current_position, action):
    """
    Return the position reached by applying a movement action to ``current_position``.

    :param current_position: Indexable pair of coordinates; only indices 0 and 1 are read.
    :param action: Lower-case direction name, e.g. ``'north'`` or ``'southwest'``.
    :return: New two-element list. For an unknown action a message is printed and a copy of
             the unchanged position is returned, so the result type is the same in both
             cases and never aliases the caller's object.
    """
    direction_mapping = {
        'north': (0, -1),
        'south': (0, 1),
        'east': (1, 0),
        'west': (-1, 0),
        'northeast': (1, -1),
        'northwest': (-1, -1),
        'southeast': (1, 1),
        'southwest': (-1, 1)
    }
    delta = direction_mapping.get(action)
    if delta is None:
        print(f"No valid action found for {action}.")
        return [current_position[0], current_position[1]]
    return [current_position[0] + delta[0], current_position[1] + delta[1]]

View File

@@ -29,10 +29,9 @@ class Renderer:
AGENT_VIEW_COLOR = (9, 132, 227) AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent.parent ASSETS = Path(__file__).parent.parent
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None,
lvl_padded_shape: Union[Tuple[int, int], None] = None, cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2,
cell_size: int = 40, fps: int = 7, factor: float = 0.9, custom_assets_path=None):
grid_lines: bool = True, view_radius: int = 2):
""" """
The Renderer class initializes and manages the rendering environment for the simulation, The Renderer class initializes and manages the rendering environment for the simulation,
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
@@ -53,7 +52,6 @@ class Renderer:
:param view_radius: Radius for agent's field of view. :param view_radius: Radius for agent's field of view.
:type view_radius: int :type view_radius: int
""" """
# TODO: Custom_assets paths
self.grid_h, self.grid_w = lvl_shape self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size self.cell_size = cell_size
@@ -64,8 +62,9 @@ class Renderer:
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size) self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png')) self.custom_assets_path = custom_assets_path
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets} self.assets = self.load_assets(custom_assets_path)
self.save_counter = 1
self.fill_bg() self.fill_bg()
# now = time.time() # now = time.time()
@@ -116,6 +115,28 @@ class Renderer:
rect.centerx, rect.centery = c_, r_ rect.centerx, rect.centery = c_, r_
return dict(source=img, dest=rect) return dict(source=img, dest=rect)
def load_assets(self, custom_assets_path, factor=1.0):
    """
    Load the image assets used for rendering.

    :param custom_assets_path: Either a dict mapping asset keys to image paths, a directory
                               that is searched recursively for ``*.png`` files, or a falsy
                               value to search the default ``ASSETS`` directory.
    :param factor: Scaling factor forwarded to ``load_asset``. Defaults to 1.0, which is
                   what ``load_asset`` uses when no factor is given.
    :type factor: float
    :return: Dict of asset key (dict key or file stem) to loaded asset. Assets that could
             not be loaded are left out and reported with a warning.
    """
    assets_source = custom_assets_path if custom_assets_path else self.ASSETS
    assets = {}
    if isinstance(assets_source, dict):
        # Explicit mapping: the caller chooses the key under which each image is stored.
        for key, path in assets_source.items():
            asset = self.load_asset(path, factor)
            if asset is None:
                print(f"Warning: Asset for key '{key}' is missing and was not loaded.")
                continue
            assets[key] = asset
    else:
        # Directory: every png found below it is stored under its file stem.
        for path in Path(assets_source).rglob('*.png'):
            asset = self.load_asset(str(path), factor)
            if asset is None:
                print(f"Warning: Asset '{path.stem}' is missing and was not loaded.")
                continue
            assets[path.stem] = asset
    return assets
def load_asset(self, path, factor=1.0): def load_asset(self, path, factor=1.0):
""" """
Loads and resizes an asset from the specified path. Loads and resizes an asset from the specified path.
@@ -126,10 +147,28 @@ class Renderer:
:type factor: float :type factor: float
:return: Resized asset. :return: Resized asset.
""" """
s = int(factor*self.cell_size) try:
asset = pygame.image.load(path).convert_alpha() s = int(factor * self.cell_size)
asset = pygame.transform.smoothscale(asset, (s, s)) asset = pygame.image.load(path).convert_alpha()
return asset asset = pygame.transform.smoothscale(asset, (s, s))
return asset
except pygame.error as e:
print(f"Failed to load asset {path}: {e}")
return self.load_default_asset()
def load_default_asset(self, factor=1.0):
    """
    Load the placeholder asset that is used when a specific asset fails to load.

    :param factor: Scaling factor relative to the cell size.
    :type factor: float
    :return: The placeholder scaled to ``factor * cell_size``, or None if it cannot be loaded.
    """
    default_path = 'marl_factory_grid/utils/plotting/action_assets/default.png'
    try:
        s = int(factor * self.cell_size)
        default_asset = pygame.image.load(default_path).convert_alpha()
        default_asset = pygame.transform.smoothscale(default_asset, (s, s))
        return default_asset
    except (pygame.error, FileNotFoundError) as e:
        # pygame 2 reports a missing file as FileNotFoundError rather than pygame.error;
        # both must end up here so that a missing placeholder yields None, not a crash.
        print(f"Failed to load default asset: {e}")
        return None
def visibility_rects(self, bp, view): def visibility_rects(self, bp, view):
""" """
@@ -196,9 +235,58 @@ class Renderer:
return np.transpose(rgb_obs, (2, 0, 1)) return np.transpose(rgb_obs, (2, 0, 1))
# return torch.from_numpy(rgb_obs).permute(2, 0, 1) # return torch.from_numpy(rgb_obs).permute(2, 0, 1)
def render_action_icons(self, action_entities):
    """
    Renders action icons based on the entities' names, positions, rotations and probabilities,
    then saves the resulting screen as an image.

    Entities with an invalid position or without a loaded asset are reported and skipped.

    :param action_entities: List of entities representing actions.
    :type action_entities: List[RenderEntity]
    """
    self.fill_bg()  # Clear the background
    font = pygame.font.Font(None, 24)  # Initialize the font once for all text rendering
    half_cell = self.cell_size // 2

    for action_entity in action_entities:
        pos = action_entity.pos
        if not isinstance(pos, np.ndarray) or pos.ndim != 1:
            print(f"Invalid position format for entity: {action_entity.pos}")
            continue

        # Look the asset up defensively: a missing key must skip the entity, and the check
        # has to happen before the image is handed to pygame.transform.rotate.
        img = self.assets.get(action_entity.name.lower())
        if img is None:
            print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
            continue

        rotation = getattr(action_entity, 'rotation', 0)
        if rotation:
            img = pygame.transform.rotate(img, rotation)

        # Blit the icon image centred in its cell.
        # NOTE(review): pos[0] is used as the horizontal and pos[1] as the vertical
        # coordinate here - confirm this matches the (row, col) order used elsewhere.
        img_rect = img.get_rect(center=(pos[0] * self.cell_size + half_cell,
                                        pos[1] * self.cell_size + half_cell))
        self.screen.blit(img, img_rect)

        # Render the probability below the icon; entities without one (None) get no label.
        probability = getattr(action_entity, 'probability', None)
        if probability is not None:
            prob_text = font.render(f"{probability:.2f}", True, (255, 0, 0))
            prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
            self.screen.blit(prob_text, prob_text_rect)

    pygame.display.flip()  # Update the display with all new blits
    self.save_screen("route_graph")
def save_screen(self, filename):
    """
    Write the current screen to a PNG file whose name carries a running counter.

    The file is named ``<filename>_agent_<counter>.png``; the counter is advanced on every
    call so that consecutive images do not overwrite each other.

    :param filename: Base name (without extension) of the image file.
    """
    target = f"{filename}_agent_{self.save_counter}.png"
    self.save_counter = self.save_counter + 1
    pygame.image.save(self.screen, target)
    print(f"Image saved as {target}")
if __name__ == '__main__': if __name__ == '__main__':
renderer = Renderer(fps=2, cell_size=40) renderer = Renderer(cell_size=40, fps=2)
for pos_i in range(15): for pos_i in range(15):
entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle') entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle')
renderer.render([entity_1]) renderer.render([entity_1])

View File

@@ -33,6 +33,8 @@ class RenderEntity:
id: int = 0 id: int = 0
aux: Any = None aux: Any = None
real_name: str = 'none' real_name: str = 'none'
probability: float = None # Default to None if not used
rotation: int = 0 # Default rotation if not specified
@dataclass @dataclass

View File

@@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from pprint import pprint
from tqdm import trange from tqdm import trange
@@ -7,12 +8,17 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
from marl_factory_grid.environment.factory import Factory from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes
if __name__ == '__main__': if __name__ == '__main__':
# Render at each step?
run_path = Path('study_out')
render = True render = True
monitor = True
record = True
# Path to config File # Path to config File
path = Path('marl_factory_grid/configs/simple_crossing.yaml') path = Path('marl_factory_grid/configs/test_config.yaml')
# Env Init # Env Init
factory = Factory(path) factory = Factory(path)
@@ -34,3 +40,5 @@ if __name__ == '__main__':
if done: if done:
print(f'Episode {episode} done...') print(f'Episode {episode} done...')
break break
plot_routes(factory, agents, )