Merge branch 'route_plotting' into rl_plotting

Chanumask
2024-05-10 12:48:06 +02:00
26 changed files with 494 additions and 127 deletions

View File

@ -33,9 +33,12 @@ class TSPBaseAgent(ABC):
self.local_optimization = True
self._env = state
self.state = self._env.state[c.AGENT][agent_i]
self.spawn_position = np.array(self.state.pos)
self._position_graph = self.generate_pos_graph()
self._static_route = None
self.cached_route = None
self.fallback_action = None
self.action_list = []
@abstractmethod
def predict(self, *_, **__) -> int:
@ -47,6 +50,46 @@ class TSPBaseAgent(ABC):
"""
return 0
def calculate_tsp_route(self, target_identifier):
"""
Calculate the TSP route to reach a target.
:param target_identifier: Identifier of the target entity
:type target_identifier: str
:return: TSP route
:rtype: List[int]
"""
target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
# if there are cached routes, search for one containing the current position and a target position
if self._env.state.route_cache and (
route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
# print(f"Retrieved cached route: {route}")
return route
# if none is found, calculate the TSP route and cache it
else:
start_time = time.time()
if self.local_optimization:
nodes = \
[self.state.pos] + \
[x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
try:
while len(nodes) < 7:
nodes += [next(x for x in target_positions if x not in nodes)]
except StopIteration:
nodes = [self.state.pos] + target_positions
else:
nodes = [self.state.pos] + target_positions
route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
duration = time.time() - start_time
print("TSP calculation took {:.2f} seconds to execute".format(duration))
self._env.state.cache_route(route)
return route
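The greedy TSP call above uses the networkx approximation API. Below is a minimal standalone sketch, assuming networkx is available and using a small grid graph as a stand-in for self._position_graph and for the agent/target positions:

import networkx as nx
from networkx.algorithms import approximation as tsp

position_graph = nx.grid_2d_graph(5, 5)           # stand-in for self._position_graph
nodes = [(0, 0), (3, 1), (1, 4), (4, 4)]          # agent position plus nearby target positions
route = tsp.traveling_salesman_problem(position_graph, nodes=nodes, cycle=True, method=tsp.greedy_tsp)
print(route)                                      # closed tour visiting all requested nodes

traveling_salesman_problem connects the requested nodes via shortest paths before running greedy_tsp, so the position graph itself does not need to be complete.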
def _use_door_or_move(self, door, target):
"""
Helper method to decide whether to use a door or move towards a target.
@ -65,47 +108,6 @@ class TSPBaseAgent(ABC):
action = self._predict_move(target)
return action
def calculate_tsp_route(self, target_identifier):
"""
Calculate the TSP route to reach a target.
:param target_identifier: Identifier of the target entity
:type target_identifier: str
:return: TSP route
:rtype: List[int]
"""
start_time = time.time()
if self.cached_route is not None:
print(f" Used cached route: {self.cached_route}")
return copy.deepcopy(self.cached_route)
else:
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
if self.local_optimization:
nodes = \
[self.state.pos] + \
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
try:
while len(nodes) < 7:
nodes += [next(x for x in positions if x not in nodes)]
except StopIteration:
nodes = [self.state.pos] + positions
else:
nodes = [self.state.pos] + positions
route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
self.cached_route = copy.deepcopy(route)
print(f"Cached route: {self.cached_route}")
end_time = time.time()
duration = end_time - start_time
print("TSP calculation took {:.2f} seconds to execute".format(duration))
return route
def _door_is_close(self, state):
"""
Check if a door is close to the agent's position.
@ -171,8 +173,11 @@ class TSPBaseAgent(ABC):
action = next(action for action, pos_diff in MOVEMAP.items() if
np.all(diff == pos_diff) and action in allowed_directions)
except StopIteration:
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
action = choice(self.state.actions).name
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
action = self.fallback_action
else:
action = choice(self.state.actions).name
else:
action = choice(self.state.actions).name
# noinspection PyUnboundLocalVariable

View File

@ -1,6 +1,7 @@
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.clean_up import constants as di
from marl_factory_grid.environment import constants as c
future_planning = 7
@ -12,6 +13,7 @@ class TSPDirtAgent(TSPBaseAgent):
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
"""
super(TSPDirtAgent, self).__init__(*args, **kwargs)
self.fallback_action = c.NOOP
def predict(self, *_, **__):
"""
@ -28,6 +30,7 @@ class TSPDirtAgent(TSPBaseAgent):
action = self._use_door_or_move(door, di.DIRT)
else:
action = self._predict_move(di.DIRT)
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@ -3,6 +3,7 @@ import numpy as np
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
future_planning = 7
inventory_size = 3
@ -22,6 +23,7 @@ class TSPItemAgent(TSPBaseAgent):
"""
super(TSPItemAgent, self).__init__(*args, **kwargs)
self.mode = mode
self.fallback_action = c.NOOP
def predict(self, *_, **__):
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
@ -36,6 +38,7 @@ class TSPItemAgent(TSPBaseAgent):
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
else:
action = self._choose()
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@ -2,6 +2,8 @@ from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.doors import constants as do
from marl_factory_grid.environment import constants as c
future_planning = 7
@ -13,6 +15,7 @@ class TSPTargetAgent(TSPBaseAgent):
Initializes a TSPTargetAgent that aims to reach destinations.
"""
super(TSPTargetAgent, self).__init__(*args, **kwargs)
self.fallback_action = c.NOOP
def _handle_doors(self, state):
"""
@ -35,6 +38,7 @@ class TSPTargetAgent(TSPBaseAgent):
action = self._use_door_or_move(door, d.DESTINATION)
else:
action = self._predict_move(d.DESTINATION)
self.action_list.append(action)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)

View File

@ -20,7 +20,7 @@ Agents:
- Destination
# Available Spawn Positions as list
Positions:
- (1,2)
- (2,1)
# It is okay to collide with other agents, so that
# they end up on the same position
is_blocking_pos: false
@ -33,7 +33,7 @@ Agents:
- Other
- Destination
Positions:
- (2,1)
- (1,2)
is_blocking_pos: false
# Other noteworthy Entities
@ -45,9 +45,9 @@ Entities:
SpawnDestinationsPerAgent:
coords_or_quantity:
Agent_horizontal:
- (3,2)
Agent_vertical:
- (2,3)
Agent_vertical:
- (3,2)
# Whether you want to provide a numeric Position observation.
# GlobalPositions:
# normalized: false

View File

@ -1,45 +1,45 @@
Agents:
Clean test agent:
Actions:
- Noop
- Charge
- Clean
- DoorUse
- Move8
Observations:
- Combined:
- Other
- Walls
- GlobalPosition
- Battery
- ChargePods
- DirtPiles
- Destinations
- Doors
- Maintainers
Clones: 0
Item test agent:
Actions:
- Noop
- Charge
- DestAction
- DoorUse
- ItemAction
- Move8
Observations:
- Combined:
- Other
- Walls
- GlobalPosition
- Battery
- ChargePods
- Destinations
- Doors
- Items
- Inventory
- DropOffLocations
- Maintainers
Clones: 0
# Clean test agent:
# Actions:
# - Noop
# - Charge
# - Clean
# - DoorUse
# - Move8
# Observations:
# - Combined:
# - Other
# - Walls
# - GlobalPosition
# - Battery
# - ChargePods
# - DirtPiles
# - Destinations
# - Doors
# - Maintainers
# Clones: 0
# Item test agent:
# Actions:
# - Noop
# - Charge
# - DestAction
# - DoorUse
# - ItemAction
# - Move8
# Observations:
# - Combined:
# - Other
# - Walls
# - GlobalPosition
# - Battery
# - ChargePods
# - Destinations
# - Doors
# - Items
# - Inventory
# - DropOffLocations
# - Maintainers
# Clones: 0
Target test agent:
Actions:
- Noop
@ -55,7 +55,7 @@ Agents:
- Destinations
- Doors
- Maintainers
Clones: 0
Clones: 1
Entities:
@ -90,7 +90,7 @@ Entities:
General:
env_seed: 69
individual_rewards: true
level_name: large
level_name: quadrant
pomdp_r: 3
verbose: false
tests: false
@ -115,10 +115,10 @@ Rules:
# Done Conditions
DoneAtMaxStepsReached:
max_steps: 500
max_steps: 20
Tests:
MaintainerTest: {}
DirtAgentTest: {}
ItemAgentTest: {}
TargetAgentTest: {}
# MaintainerTest: {}
# DirtAgentTest: {}
# ItemAgentTest: {}
# TargetAgentTest: {}

View File

@ -42,6 +42,7 @@ class LevelParser(object):
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
self.level_shape = level_array.shape
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
self.walls = None
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
"""
@ -74,6 +75,7 @@ class LevelParser(object):
# Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALLS: walls})
self.walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)
# Agents
entities.add_items({c.AGENT: Agents(self.size)})

View File

@ -48,7 +48,7 @@ class EnvRecorder(Wrapper):
"""
obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes:
summary: dict = self.env.summarize_state()
summary: dict = self.env.unwrapped.summarize_state()
# summary.update(done=done)
# summary.update({'episode': self._curr_episode})
# TODO Protobuff Adjustments ######

13 binary image files added (not shown); sizes range from 425 B to 6.6 KiB.

View File

@ -1,13 +1,18 @@
import json
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
from marl_factory_grid.utils.renderer import Renderer
from marl_factory_grid.utils.utility_classes import RenderEntity
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str = 'monitor', file_ext: str = 'pkl'):
@ -60,3 +65,156 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')
def plot_routes(factory, agents):
"""
Creates a plot of the agents' actions on the level map by creating a Renderer and RenderEntity objects that hold the
icon corresponding to each action. For deterministic agents it simply displays the agent's path of actions, while RL
agents can instead supply an action map or action probabilities from their policy net.
"""
renderer = Renderer(factory.map.level_shape, custom_assets_path={
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
'use_door': 'marl_factory_grid/utils/plotting/action_assets/door_action.png',
'wall': 'marl_factory_grid/environment/assets/wall.png',
'machine_action': 'marl_factory_grid/utils/plotting/action_assets/machine_action.png',
'clean_action': 'marl_factory_grid/utils/plotting/action_assets/clean_action.png',
'destination_action': 'marl_factory_grid/utils/plotting/action_assets/destination_action.png',
'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})
wall_positions = swap_coordinates(factory.map.walls)
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
action_entities = list(wall_entities)
for index, agent in enumerate(agents):
current_position = swap_coordinates(agent.spawn_position)
if hasattr(agent, 'action_probabilities'):
# Handle RL agents with action probabilities
top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
else:
# Handle deterministic agents by iterating through all actions in the list
top_actions = [(action, 1.0) for action in agent.action_list]
for action, probability in top_actions:
if action.lower() in rotation_mapping:
base_icon, rotation = rotation_mapping[action.lower()]
icon_name = 'cardinal' if 'diagonal' not in base_icon else 'diagonal'
new_position = action_to_coords(current_position, action.lower())
else:
icon_name = action.lower()
rotation = 0
new_position = current_position
action_entity = RenderEntity(
name=icon_name,
pos=np.array(current_position),
probability=probability,
rotation=rotation
)
action_entities.append(action_entity)
current_position = new_position
renderer.render_single_action_icons(action_entities)  # move this call inside the agent loop to get one graph per agent
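A hedged sketch of the duck-typed agent interface plot_routes reads: each agent needs a spawn_position plus either an action_list (deterministic TSP agents) or action_probabilities (RL agents). The SimpleNamespace stand-ins below are purely illustrative:

import numpy as np
from types import SimpleNamespace

tsp_like = SimpleNamespace(spawn_position=np.array([2, 1]),
                           action_list=['north', 'use_door', 'east'])
rl_like = SimpleNamespace(spawn_position=np.array([3, 3]),
                          action_probabilities={'north': 0.6, 'east': 0.25, 'noop': 0.1, 'south': 0.05})
# plot_routes(factory, [tsp_like, rl_like]) would render both, given a factory instance

Action names are lowercased and looked up either in rotation_mapping or among the custom asset keys, so they must match one of those.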
def plot_action_maps(factory, agents):
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path={
'green_arrow': 'marl_factory_grid/utils/plotting/action_assets/green_arrow.png',
'yellow_arrow': 'marl_factory_grid/utils/plotting/action_assets/yellow_arrow.png',
'red_arrow': 'marl_factory_grid/utils/plotting/action_assets/red_arrow.png',
'grey_arrow': 'marl_factory_grid/utils/plotting/action_assets/grey_arrow.png',
'wall': 'marl_factory_grid/environment/assets/wall.png',
})
directions = ['north', 'east', 'south', 'west']
wall_positions = swap_coordinates(factory.map.walls)
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
action_entities = list(wall_entities)
dummy_action_map = load_action_map("example_action_map.txt")
for agent in agents:
# if hasattr(agent, 'action_probability_map'):
# for y in range(len(agent.action_probability_map)):
for y in range(len(dummy_action_map)):
# for x in range(len(agent.action_probability_map[y])):
for x in range(len(dummy_action_map[y])):
position = (x, y)
if position not in wall_positions:
# action_probabilities = agent.action_probability_map[y][x]
action_probabilities = dummy_action_map[y][x]
if sum(action_probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
# Sort actions by probability and assign colors
sorted_indices = sorted(range(len(action_probabilities)),
key=lambda i: -action_probabilities[i])
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
for rank, direction_index in enumerate(sorted_indices):
action = directions[direction_index]
probability = action_probabilities[direction_index]
arrow_color = colors[rank]
if probability > 0:
action_entity = RenderEntity(
name=arrow_color,
pos=position,
probability=probability,
rotation=direction_index * 90
)
action_entities.append(action_entity)
renderer.render_multi_action_icons(action_entities)
def load_action_map(file_path):
with open(file_path, 'r') as file:
action_map = json.load(file)
return action_map
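The format of example_action_map.txt is not part of this commit; judging from how plot_action_maps indexes it, it is a JSON grid in which each cell stores four probabilities ordered north, east, south, west, and all-zero cells mark walls. A hypothetical example:

import json
import tempfile

dummy_map = [
    [[0, 0, 0, 0], [0.7, 0.1, 0.1, 0.1], [0, 0, 0, 0]],
    [[0.25, 0.25, 0.25, 0.25], [0.1, 0.6, 0.2, 0.1], [0.0, 0.9, 0.05, 0.05]],
    [[0, 0, 0, 0], [0.3, 0.3, 0.2, 0.2], [0, 0, 0, 0]],
]
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    json.dump(dummy_map, f)
with open(f.name) as fh:
    action_map = json.load(fh)          # equivalent to load_action_map(f.name)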
def swap_coordinates(positions):
"""
Swaps x and y coordinates of single positions, lists or arrays
"""
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
return positions[1], positions[0]
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
return positions[1], positions[0]
else:
return [(y, x) for x, y in positions]
def action_to_coords(current_position, action):
"""
Calculates new coordinates based on the current position and a movement action.
"""
delta = direction_mapping.get(action)
if delta is not None:
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
return new_position
print(f"No valid movement action found for {action}.")
return current_position
rotation_mapping = {
'north': ('cardinal', 0),
'east': ('cardinal', 270),
'south': ('cardinal', 180),
'west': ('cardinal', 90),
'north_east': ('diagonal', 0),
'south_east': ('diagonal', 270),
'south_west': ('diagonal', 180),
'north_west': ('diagonal', 90)
}
direction_mapping = {
'north': (0, -1),
'south': (0, 1),
'east': (1, 0),
'west': (-1, 0),
'north_east': (1, -1),
'north_west': (-1, -1),
'south_east': (1, 1),
'south_west': (-1, 1)
}
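A short usage sketch of the coordinate helpers and mappings above, importable as in the run script at the end of this commit:

from marl_factory_grid.utils.plotting.plot_single_runs import (
    swap_coordinates, action_to_coords, rotation_mapping)

screen_pos = swap_coordinates((2, 5))               # grid (row=2, col=5) -> screen (x=5, y=2)
next_pos = action_to_coords(screen_pos, 'north')    # 'north' moves one cell up: [5, 1]
icon, rotation = rotation_mapping['north']          # ('cardinal', 0)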

View File

@ -1,7 +1,7 @@
import sys
from pathlib import Path
from collections import deque
from collections import deque, defaultdict
from itertools import product
import numpy as np
@ -24,15 +24,14 @@ SCALE: str = 'scale'
class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200)
BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent.parent
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2,
custom_assets_path=None):
"""
The Renderer class initializes and manages the rendering environment for the simulation,
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
@ -53,7 +52,6 @@ class Renderer:
:param view_radius: Radius for agent's field of view.
:type view_radius: int
"""
# TODO: Custom_assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size
@ -61,17 +59,18 @@ class Renderer:
self.grid_lines = grid_lines
self.view_radius = view_radius
pygame.init()
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
self.screen_size = (self.grid_w * cell_size, self.grid_h * cell_size)
self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
self.custom_assets_path = custom_assets_path
self.assets = self.load_assets(custom_assets_path)
self.save_counter = 1
self.fill_bg()
now = time.time()
# now = time.time()
self.font = pygame.font.Font(None, 20)
self.font.set_bold(True)
print('Loading System font with pygame.font.Font took', time.time() - now)
# print('Loading System font with pygame.font.Font took', time.time() - now)
def fill_bg(self):
"""
@ -100,22 +99,44 @@ class Renderer:
(self.lvl_padded_shape[1] - self.grid_w) // 2
r, c = entity.pos
r, c = r - offset_r, c-offset_c
r, c = r - offset_r, c - offset_c
img = self.assets[entity.name.lower()]
if entity.value_operation == OPACITY:
img.set_alpha(255*entity.value)
img.set_alpha(255 * entity.value)
elif entity.value_operation == SCALE:
re = img.get_rect()
img = pygame.transform.smoothscale(
img, (int(entity.value*re.width), int(entity.value*re.height))
img, (int(entity.value * re.width), int(entity.value * re.height))
)
o = self.cell_size//2
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
o = self.cell_size // 2
r_, c_ = r * self.cell_size + o, c * self.cell_size + o
rect = img.get_rect()
rect.centerx, rect.centery = c_, r_
return dict(source=img, dest=rect)
def load_assets(self, custom_assets_path):
"""
Loads assets from the custom path if provided, otherwise from the default path.
"""
assets_directory = custom_assets_path if custom_assets_path else self.ASSETS
assets = {}
if isinstance(assets_directory, dict):
for key, path in assets_directory.items():
asset = self.load_asset(path)
if asset is not None:
assets[key] = asset
else:
print(f"Warning: Asset for key '{key}' is missing and was not loaded.")
else:
for path in Path(assets_directory).rglob('*.png'):
asset = self.load_asset(str(path))
if asset is not None:
assets[path.stem] = asset
else:
print(f"Warning: Asset '{path.stem}' is missing and was not loaded.")
return assets
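A usage sketch of the custom-asset loading, assuming the asset paths referenced elsewhere in this commit exist on disk; missing or unreadable files fall back to load_default_asset() and print a warning:

from marl_factory_grid.utils.renderer import Renderer

renderer = Renderer((10, 10), cell_size=40, custom_assets_path={
    'wall': 'marl_factory_grid/environment/assets/wall.png',
    'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
})

Passing a dict maps asset names to explicit files; passing nothing keeps the previous behaviour of scanning the default assets directory for *.png files.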
def load_asset(self, path, factor=1.0):
"""
Loads and resizes an asset from the specified path.
@ -126,10 +147,28 @@ class Renderer:
:type factor: float
:return: Resized asset.
"""
s = int(factor*self.cell_size)
asset = pygame.image.load(path).convert_alpha()
asset = pygame.transform.smoothscale(asset, (s, s))
return asset
try:
s = int(factor * self.cell_size)
asset = pygame.image.load(path).convert_alpha()
asset = pygame.transform.smoothscale(asset, (s, s))
return asset
except pygame.error as e:
print(f"Failed to load asset {path}: {e}")
return self.load_default_asset()
def load_default_asset(self, factor=1.0):
"""
Loads a default asset to be used when specific assets fail to load.
"""
default_path = 'marl_factory_grid/utils/plotting/action_assets/default.png'
try:
s = int(factor * self.cell_size)
default_asset = pygame.image.load(default_path).convert_alpha()
default_asset = pygame.transform.smoothscale(default_asset, (s, s))
return default_asset
except pygame.error as e:
print(f"Failed to load default asset: {e}")
return None
def visibility_rects(self, bp, view):
"""
@ -143,13 +182,13 @@ class Renderer:
:rtype: List[dict]
"""
rects = []
for i, j in product(range(-self.view_radius, self.view_radius+1),
range(-self.view_radius, self.view_radius+1)):
for i, j in product(range(-self.view_radius, self.view_radius + 1),
range(-self.view_radius, self.view_radius + 1)):
if view is not None:
if bool(view[self.view_radius+j, self.view_radius+i]):
if bool(view[self.view_radius + j, self.view_radius + i]):
visibility_rect = bp['dest'].copy()
visibility_rect.centerx += i*self.cell_size
visibility_rect.centery += j*self.cell_size
visibility_rect.centerx += i * self.cell_size
visibility_rect.centery += j * self.cell_size
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
shape_surf.set_alpha(64)
@ -183,7 +222,7 @@ class Renderer:
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
)
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0] - .07 * self.cell_size,
agent_blit['dest'].center[1]))
blits += [agent_blit, state_blit, text_blit]
@ -196,9 +235,113 @@ class Renderer:
return np.transpose(rgb_obs, (2, 0, 1))
# return torch.from_numpy(rgb_obs).permute(2, 0, 1)
def render_single_action_icons(self, action_entities):
"""
Renders action icons based on each action entity's name, position, rotation and probability.
Probabilities other than 0 are rendered next to the icons.
:param action_entities: List of entities representing actions.
:type action_entities: List[RenderEntity]
"""
self.fill_bg()
for action_entity in action_entities:
if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
print(f"Invalid position format for entity: {action_entity.pos}")
continue
# Load and potentially rotate the icon based on action name
img = self.assets[action_entity.name.lower()]
if img is None:
print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
continue
if hasattr(action_entity, 'rotation'):
img = pygame.transform.rotate(img, action_entity.rotation)
# Blit the icon image
img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,
action_entity.pos[1] * self.cell_size + self.cell_size // 2))
self.screen.blit(img, img_rect)
# Render the probability next to the icon if it exists
if hasattr(action_entity, 'probability') and action_entity.probability != 0:
prob_text = self.font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
self.screen.blit(prob_text, prob_text_rect)
pygame.display.flip() # Update the display with all new blits
self.save_screen("route_graph")
def render_multi_action_icons(self, action_entities):
"""
Renders multiple action icons at the same position without overlap, arranging them by direction; walls are the
exception and cover their entire grid cell.
"""
self.fill_bg()
font = pygame.font.Font(None, 20)
# prepare position dict to iterate over
position_dict = defaultdict(list)
for entity in action_entities:
position_dict[tuple(entity.pos)].append(entity)
for position, entities in position_dict.items():
entity_size = self.cell_size // 2 # Adjust size to fit multiple entities for non-wall entities
entities.sort(key=lambda x: x.rotation)
for entity in entities:
img = self.assets[entity.name.lower()]
if img is None:
print(f"Error: No asset available for '{entity.name}'. Skipping rendering this entity.")
continue
# Check if the entity is a wall and adjust the size and position accordingly
if entity.name == 'wall':
img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
img_rect = img.get_rect(center=(position[0] * self.cell_size + self.cell_size // 2,
position[1] * self.cell_size + self.cell_size // 2))
else:
# Define offsets for each direction based on a quadrant layout within the cell
offsets = {
0: (0, -entity_size // 2), # North
90: (-entity_size // 2, 0), # West
180: (0, entity_size // 2), # South
270: (entity_size // 2, 0) # East
}
img = pygame.transform.scale(img, (int(entity_size), entity_size))
offset = offsets.get(entity.rotation, (0, 0))
img_rect = img.get_rect(center=(
position[0] * self.cell_size + self.cell_size // 2 + offset[0],
position[1] * self.cell_size + self.cell_size // 2 + offset[1]
))
img = pygame.transform.rotate(img, entity.rotation)
self.screen.blit(img, img_rect)
# Render the probability next to the icon if it exists and is non-zero
if entity.probability > 0 and entity.name != 'wall':
formatted_probability = f"{entity.probability * 100:.2f}"
prob_text = font.render(formatted_probability, True, (0, 0, 0))
prob_text_rect = prob_text.get_rect(center=img_rect.center) # Center text on the arrow
self.screen.blit(prob_text, prob_text_rect)
pygame.display.flip()
self.save_screen("multi_action_graph")
def save_screen(self, filename):
"""
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
:param filename: The base filename under which to save the image.
"""
unique_filename = f"{filename}_agent_{self.save_counter}.png"
self.save_counter += 1
pygame.image.save(self.screen, unique_filename)
print(f"Image saved as {unique_filename}")
if __name__ == '__main__':
renderer = Renderer(fps=2, cell_size=40)
renderer = Renderer(cell_size=40, fps=2)
for pos_i in range(15):
entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle')
renderer.render([entity_1])

View File

@ -1,3 +1,4 @@
import copy
from itertools import islice
from typing import List, Tuple
@ -116,6 +117,7 @@ class Gamestate(object):
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*rules)
self._floortile_graph = None
self.route_cache = []
self.tests = StepTests(*tests)
# Pointer that defines current spawn points of agents
@ -320,6 +322,42 @@ class Gamestate(object):
# json_file.seek(0)
# json.dump(existing_content, json_file, indent=4)
def cache_route(self, route):
"""
Save a route in the env-level cache so that all agents can access it.
:param route: The route to be saved
"""
self.route_cache.append(copy.deepcopy(route))
# print(f"Cached route: {route}")
def get_cached_route(self, current_pos, target_positions, route_cutting=False):
"""
Return a cached route if one contains both the current position and a target position.
:param current_pos: The agent's current position, which becomes the first position of the returned route
:param target_positions: The positions of targets the agent wants to visit
:param route_cutting: if True, the returned route is cut to end at the first target; if False, the remaining route is returned in full so that target agents can keep looping.
:returns: A cached route from the agent's position to the first target, or None if no matching route is cached
"""
if not self.route_cache:
return None
for route in self.route_cache:
if current_pos in route:
targets = [target for target in target_positions if target in route]
if targets:
first_target = targets[0]
index_start = route.index(current_pos)
if route_cutting:
index_end = route.index(first_target) + 1
return copy.deepcopy(route[index_start:index_end])
else:
return copy.deepcopy(route[index_start:])
return None
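A standalone sketch of the lookup semantics above, mirroring get_cached_route on a plain list rather than the Gamestate class, to show the effect of route_cutting:

import copy

route_cache = [[(1, 1), (1, 2), (2, 2), (3, 2), (3, 3), (2, 3), (1, 3), (1, 2), (1, 1)]]
current_pos, target_positions = (2, 2), [(3, 3)]

for route in route_cache:
    if current_pos in route:
        targets = [t for t in target_positions if t in route]
        if targets:
            start = route.index(current_pos)
            end = route.index(targets[0]) + 1
            print(copy.deepcopy(route[start:end]))   # route_cutting=True:  [(2, 2), (3, 2), (3, 3)]
            print(copy.deepcopy(route[start:]))      # route_cutting=False: remainder of the loop back to (1, 1)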
class StepTests:
def __init__(self, *args):

View File

@ -33,6 +33,8 @@ class RenderEntity:
id: int = 0
aux: Any = None
real_name: str = 'none'
probability: float = None # Default to None if not used
rotation: int = 0 # Default rotation if not specified
@dataclass

View File

@ -1,4 +1,5 @@
from pathlib import Path
from pprint import pprint
from tqdm import trange
@ -7,12 +8,17 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes, plot_action_maps
if __name__ == '__main__':
# Render at each step?
run_path = Path('study_out')
render = True
monitor = True
record = True
# Path to config File
path = Path('marl_factory_grid/configs/simple_crossing.yaml')
path = Path('marl_factory_grid/configs/test_config.yaml')
# Env Init
factory = Factory(path)
@ -25,6 +31,7 @@ if __name__ == '__main__':
action_spaces = factory.action_space
# agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
# agents = [TSPTargetAgent(factory, 0)]
while not done:
a = [x.predict() for x in agents]
obs_type, _, _, done, info = factory.step(a)
@ -33,3 +40,5 @@ if __name__ == '__main__':
if done:
print(f'Episode {episode} done...')
break
plot_action_maps(factory, agents)