mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-08 02:21:36 +02:00
Adapted commit: "started visualization of routes in plot single runs, assets missing."
This commit is contained in:

committed by
Julian Schönberger

parent
4571dc1cd1
commit
0d5b20a16f
@ -33,9 +33,11 @@ class TSPBaseAgent(ABC):
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
self.state = self._env.state[c.AGENT][agent_i]
|
||||
self.spawn_position = np.array(self.state.pos)
|
||||
self._position_graph = self.generate_pos_graph()
|
||||
self._static_route = None
|
||||
self.cached_route = None
|
||||
self.action_list = []
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, *_, **__) -> int:
|
||||
@ -78,7 +80,7 @@ class TSPBaseAgent(ABC):
|
||||
start_time = time.time()
|
||||
|
||||
if self.cached_route is not None:
|
||||
print(f" Used cached route: {self.cached_route}")
|
||||
#print(f" Used cached route: {self.cached_route}")
|
||||
return copy.deepcopy(self.cached_route)
|
||||
|
||||
else:
|
||||
@ -99,11 +101,11 @@ class TSPBaseAgent(ABC):
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
self.cached_route = copy.deepcopy(route)
|
||||
print(f"Cached route: {self.cached_route}")
|
||||
#print(f"Cached route: {self.cached_route}")
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
#print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
return route
|
||||
|
||||
def _door_is_close(self, state):
|
||||
|
@ -28,6 +28,7 @@ class TSPDirtAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, di.DIRT)
|
||||
else:
|
||||
action = self._predict_move(di.DIRT)
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
@ -36,6 +36,7 @@ class TSPItemAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
||||
else:
|
||||
action = self._choose()
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
@ -35,6 +35,7 @@ class TSPTargetAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, d.DESTINATION)
|
||||
else:
|
||||
action = self._predict_move(d.DESTINATION)
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
@ -1,45 +1,45 @@
|
||||
Agents:
|
||||
Clean test agent:
|
||||
Actions:
|
||||
- Noop
|
||||
- Charge
|
||||
- Clean
|
||||
- DoorUse
|
||||
- Move8
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Destinations
|
||||
- Doors
|
||||
- Maintainers
|
||||
Clones: 0
|
||||
Item test agent:
|
||||
Actions:
|
||||
- Noop
|
||||
- Charge
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Move8
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- Destinations
|
||||
- Doors
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Maintainers
|
||||
Clones: 0
|
||||
# Clean test agent:
|
||||
# Actions:
|
||||
# - Noop
|
||||
# - Charge
|
||||
# - Clean
|
||||
# - DoorUse
|
||||
# - Move8
|
||||
# Observations:
|
||||
# - Combined:
|
||||
# - Other
|
||||
# - Walls
|
||||
# - GlobalPosition
|
||||
# - Battery
|
||||
# - ChargePods
|
||||
# - DirtPiles
|
||||
# - Destinations
|
||||
# - Doors
|
||||
# - Maintainers
|
||||
# Clones: 0
|
||||
# Item test agent:
|
||||
# Actions:
|
||||
# - Noop
|
||||
# - Charge
|
||||
# - DestAction
|
||||
# - DoorUse
|
||||
# - ItemAction
|
||||
# - Move8
|
||||
# Observations:
|
||||
# - Combined:
|
||||
# - Other
|
||||
# - Walls
|
||||
# - GlobalPosition
|
||||
# - Battery
|
||||
# - ChargePods
|
||||
# - Destinations
|
||||
# - Doors
|
||||
# - Items
|
||||
# - Inventory
|
||||
# - DropOffLocations
|
||||
# - Maintainers
|
||||
# Clones: 0
|
||||
Target test agent:
|
||||
Actions:
|
||||
- Noop
|
||||
@ -55,7 +55,7 @@ Agents:
|
||||
- Destinations
|
||||
- Doors
|
||||
- Maintainers
|
||||
Clones: 0
|
||||
Clones: 1
|
||||
|
||||
Entities:
|
||||
|
||||
@ -118,7 +118,7 @@ Rules:
|
||||
max_steps: 500
|
||||
|
||||
Tests:
|
||||
MaintainerTest: {}
|
||||
DirtAgentTest: {}
|
||||
ItemAgentTest: {}
|
||||
TargetAgentTest: {}
|
||||
# MaintainerTest: {}
|
||||
# DirtAgentTest: {}
|
||||
# ItemAgentTest: {}
|
||||
# TargetAgentTest: {}
|
||||
|
@ -42,6 +42,7 @@ class LevelParser(object):
|
||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
||||
self.level_shape = level_array.shape
|
||||
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
self.walls = None
|
||||
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
|
||||
"""
|
||||
@ -74,6 +75,7 @@ class LevelParser(object):
|
||||
# Walls
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALLS: walls})
|
||||
self.walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)
|
||||
|
||||
# Agents
|
||||
entities.add_items({c.AGENT: Agents(self.size)})
|
||||
|
@ -48,7 +48,7 @@ class EnvRecorder(Wrapper):
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
summary: dict = self.env.summarize_state()
|
||||
summary: dict = self.env.unwrapped.summarize_state()
|
||||
# summary.update(done=done)
|
||||
# summary.update({'episode': self._curr_episode})
|
||||
# TODO Protobuff Adjustments ######
|
||||
|
@ -3,11 +3,15 @@ from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
from marl_factory_grid.utils.renderer import Renderer
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
||||
@ -60,3 +64,81 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
rotation_mapping = {
|
||||
'north': ('cardinal', 0),
|
||||
'east': ('cardinal', 270),
|
||||
'south': ('cardinal', 180),
|
||||
'west': ('cardinal', 90),
|
||||
'northeast': ('diagonal', 0),
|
||||
'southeast': ('diagonal', 270),
|
||||
'southwest': ('diagonal', 180),
|
||||
'northwest': ('diagonal', 90)
|
||||
}
|
||||
|
||||
|
||||
def plot_routes(factory, agents):
|
||||
renderer = Renderer(factory.map.level_shape, custom_assets_path={
|
||||
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
|
||||
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
|
||||
'door': 'marl_factory_grid/utils/plotting/action_assets/door.png',
|
||||
'wall': 'marl_factory_grid/environment/assets/wall.png'})
|
||||
|
||||
wall_positions = factory.map.walls
|
||||
|
||||
for index, agent in enumerate(agents):
|
||||
action_entities = []
|
||||
# Add walls to the action_entities list
|
||||
for pos in wall_positions:
|
||||
wall_entity = RenderEntity(
|
||||
name='wall',
|
||||
probability=1.0,
|
||||
pos=np.array(pos),
|
||||
)
|
||||
action_entities.append(wall_entity)
|
||||
current_position = agent.spawn_position
|
||||
|
||||
if hasattr(agent, 'action_probabilities'):
|
||||
# Handle RL agents with action probabilities
|
||||
top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
|
||||
else:
|
||||
# Handle deterministic agents by iterating through all actions in the list
|
||||
top_actions = [(action, 1.0) for action in agent.action_list]
|
||||
|
||||
for action, probability in top_actions:
|
||||
base_icon, rotation = rotation_mapping.get(action.lower(), ('north', 0))
|
||||
icon_name = base_icon
|
||||
new_position = action_to_coords(current_position, action.lower())
|
||||
print(f"current position type and value: {type(current_position)}, {new_position}")
|
||||
|
||||
action_entity = RenderEntity(
|
||||
name=icon_name,
|
||||
pos=np.array(current_position),
|
||||
probability=probability,
|
||||
rotation=rotation
|
||||
)
|
||||
action_entities.append(action_entity)
|
||||
current_position = new_position
|
||||
|
||||
renderer.render_action_icons(action_entities)
|
||||
|
||||
|
||||
def action_to_coords(current_position, action):
|
||||
direction_mapping = {
|
||||
'north': (0, -1),
|
||||
'south': (0, 1),
|
||||
'east': (1, 0),
|
||||
'west': (-1, 0),
|
||||
'northeast': (1, -1),
|
||||
'northwest': (-1, -1),
|
||||
'southeast': (1, 1),
|
||||
'southwest': (-1, 1)
|
||||
}
|
||||
delta = direction_mapping.get(action)
|
||||
if delta is not None:
|
||||
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
|
||||
return new_position
|
||||
print(f"No valid action found for {action}.")
|
||||
return current_position
|
||||
|
||||
|
@ -29,10 +29,9 @@ class Renderer:
|
||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||
ASSETS = Path(__file__).parent.parent
|
||||
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2,
|
||||
custom_assets_path=None):
|
||||
"""
|
||||
The Renderer class initializes and manages the rendering environment for the simulation,
|
||||
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
|
||||
@ -53,7 +52,6 @@ class Renderer:
|
||||
:param view_radius: Radius for agent's field of view.
|
||||
:type view_radius: int
|
||||
"""
|
||||
# TODO: Custom_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -64,8 +62,9 @@ class Renderer:
|
||||
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list(self.ASSETS.rglob('*.png'))
|
||||
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
||||
self.custom_assets_path = custom_assets_path
|
||||
self.assets = self.load_assets(custom_assets_path)
|
||||
self.save_counter = 1
|
||||
self.fill_bg()
|
||||
|
||||
now = time.time()
|
||||
@ -116,6 +115,28 @@ class Renderer:
|
||||
rect.centerx, rect.centery = c_, r_
|
||||
return dict(source=img, dest=rect)
|
||||
|
||||
def load_assets(self, custom_assets_path):
|
||||
"""
|
||||
Loads assets from the custom path if provided, otherwise from the default path.
|
||||
"""
|
||||
assets_directory = custom_assets_path if custom_assets_path else self.ASSETS
|
||||
assets = {}
|
||||
if isinstance(assets_directory, dict):
|
||||
for key, path in assets_directory.items():
|
||||
asset = self.load_asset(path)
|
||||
if asset is not None:
|
||||
assets[key] = asset
|
||||
else:
|
||||
print(f"Warning: Asset for key '{key}' is missing and was not loaded.")
|
||||
else:
|
||||
for path in Path(assets_directory).rglob('*.png'):
|
||||
asset = self.load_asset(str(path))
|
||||
if asset is not None:
|
||||
assets[path.stem] = asset
|
||||
else:
|
||||
print(f"Warning: Asset '{path.stem}' is missing and was not loaded.")
|
||||
return assets
|
||||
|
||||
def load_asset(self, path, factor=1.0):
|
||||
"""
|
||||
Loads and resizes an asset from the specified path.
|
||||
@ -126,10 +147,28 @@ class Renderer:
|
||||
:type factor: float
|
||||
:return: Resized asset.
|
||||
"""
|
||||
s = int(factor*self.cell_size)
|
||||
asset = pygame.image.load(path).convert_alpha()
|
||||
asset = pygame.transform.smoothscale(asset, (s, s))
|
||||
return asset
|
||||
try:
|
||||
s = int(factor * self.cell_size)
|
||||
asset = pygame.image.load(path).convert_alpha()
|
||||
asset = pygame.transform.smoothscale(asset, (s, s))
|
||||
return asset
|
||||
except pygame.error as e:
|
||||
print(f"Failed to load asset {path}: {e}")
|
||||
return self.load_default_asset()
|
||||
|
||||
def load_default_asset(self, factor=1.0):
|
||||
"""
|
||||
Loads a default asset to be used when specific assets fail to load.
|
||||
"""
|
||||
default_path = 'marl_factory_grid/utils/plotting/action_assets/default.png'
|
||||
try:
|
||||
s = int(factor * self.cell_size)
|
||||
default_asset = pygame.image.load(default_path).convert_alpha()
|
||||
default_asset = pygame.transform.smoothscale(default_asset, (s, s))
|
||||
return default_asset
|
||||
except pygame.error as e:
|
||||
print(f"Failed to load default asset: {e}")
|
||||
return None
|
||||
|
||||
def visibility_rects(self, bp, view):
|
||||
"""
|
||||
@ -201,9 +240,58 @@ class Renderer:
|
||||
return np.transpose(rgb_obs, (2, 0, 1))
|
||||
# return torch.from_numpy(rgb_obs).permute(2, 0, 1)
|
||||
|
||||
def render_action_icons(self, action_entities):
|
||||
"""
|
||||
Renders action icons based on the entities' specified actions, positions, and probabilities.
|
||||
|
||||
:param action_entities: List of entities representing actions.
|
||||
:type action_entities: List[RenderEntity]
|
||||
"""
|
||||
|
||||
self.fill_bg() # Clear the background
|
||||
font = pygame.font.Font(None, 24) # Initialize the font once for all text rendering
|
||||
|
||||
for action_entity in action_entities:
|
||||
if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
|
||||
print(f"Invalid position format for entity: {action_entity.pos}")
|
||||
continue
|
||||
|
||||
# Load and potentially rotate the icon based on action name
|
||||
img = self.assets[action_entity.name.lower()]
|
||||
if hasattr(action_entity, 'rotation'):
|
||||
img = pygame.transform.rotate(img, action_entity.rotation)
|
||||
if img is None:
|
||||
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
|
||||
continue
|
||||
|
||||
# Blit the icon image
|
||||
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,
|
||||
action_entity.pos[1] * self.cell_size + self.cell_size // 2))
|
||||
self.screen.blit(img, img_rect)
|
||||
|
||||
# Render the probability next to the icon if it exists
|
||||
if hasattr(action_entity, 'probability'):
|
||||
prob_text = font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
|
||||
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
|
||||
self.screen.blit(prob_text, prob_text_rect)
|
||||
|
||||
pygame.display.flip() # Update the display with all new blits
|
||||
self.save_screen("route_graph")
|
||||
|
||||
def save_screen(self, filename):
|
||||
"""
|
||||
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
|
||||
:param filename: The base filename where to save the image.
|
||||
:param agent_id: Unique identifier for the agent.
|
||||
"""
|
||||
unique_filename = f"{filename}_agent_{self.save_counter}.png"
|
||||
self.save_counter += 1
|
||||
pygame.image.save(self.screen, unique_filename)
|
||||
print(f"Image saved as {unique_filename}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
renderer = Renderer(fps=2, cell_size=40)
|
||||
renderer = Renderer(cell_size=40, fps=2)
|
||||
for pos_i in range(15):
|
||||
entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle')
|
||||
renderer.render([entity_1])
|
||||
|
@ -33,6 +33,8 @@ class RenderEntity:
|
||||
id: int = 0
|
||||
aux: Any = None
|
||||
real_name: str = 'none'
|
||||
probability: float = None # Default to None if not used
|
||||
rotation: int = 0 # Default rotation if not specified
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -94,4 +94,4 @@ def two_rooms_one_door_modified_multi_agent_eval(emergent_phenomenon):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dirt_quadrant_single_agent_training()
|
||||
dirt_quadrant_5_multi_agent_ctde_eval(True)
|
@ -104,4 +104,4 @@ def two_rooms_one_door_modified_multi_agent_tsp(emergent_phenomenon):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dirt_quadrant_multi_agent_tsp(False)
|
||||
two_rooms_one_door_modified_multi_agent_tsp(False)
|
||||
|
12
test_run.py
12
test_run.py
@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
@ -7,12 +8,17 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
|
||||
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
|
||||
from marl_factory_grid.environment.factory import Factory
|
||||
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Render at each step?
|
||||
|
||||
run_path = Path('study_out')
|
||||
render = True
|
||||
monitor = True
|
||||
record = True
|
||||
|
||||
# Path to config File
|
||||
path = Path('marl_factory_grid/configs/simple_crossing.yaml')
|
||||
path = Path('marl_factory_grid/configs/test_config.yaml')
|
||||
|
||||
# Env Init
|
||||
factory = Factory(path)
|
||||
@ -33,3 +39,5 @@ if __name__ == '__main__':
|
||||
if done:
|
||||
print(f'Episode {episode} done...')
|
||||
break
|
||||
|
||||
plot_routes(factory, agents, )
|
||||
|
Reference in New Issue
Block a user