mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-01-15 23:41:39 +01:00
merged RL agent with action map plotting and added it to end of agents train loop
This commit is contained in:
@@ -17,6 +17,7 @@ from collections import deque
|
|||||||
|
|
||||||
from marl_factory_grid.environment.actions import Noop
|
from marl_factory_grid.environment.actions import Noop
|
||||||
from marl_factory_grid.modules import Clean, DoorUse
|
from marl_factory_grid.modules import Clean, DoorUse
|
||||||
|
from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps
|
||||||
|
|
||||||
|
|
||||||
class Names:
|
class Names:
|
||||||
@@ -571,8 +572,7 @@ class A2C:
|
|||||||
if self.cfg[nms.ENV]["save_and_log"]:
|
if self.cfg[nms.ENV]["save_and_log"]:
|
||||||
self.create_info_maps(env, used_actions)
|
self.create_info_maps(env, used_actions)
|
||||||
self.save_agent_models()
|
self.save_agent_models()
|
||||||
|
plot_action_maps(env, [self])
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode(True)
|
@torch.inference_mode(True)
|
||||||
def eval_loop(self, n_episodes, render=False):
|
def eval_loop(self, n_episodes, render=False):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
###########
|
##########D
|
||||||
#---#######
|
#---#######
|
||||||
#-----#####
|
#-----#####
|
||||||
#------####
|
#------####
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -121,50 +122,60 @@ def plot_routes(factory, agents):
|
|||||||
|
|
||||||
|
|
||||||
def plot_action_maps(factory, agents):
|
def plot_action_maps(factory, agents):
|
||||||
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path={
|
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
'green_arrow': 'marl_factory_grid/utils/plotting/action_assets/green_arrow.png',
|
assets_path = {
|
||||||
'yellow_arrow': 'marl_factory_grid/utils/plotting/action_assets/yellow_arrow.png',
|
'green_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'green_arrow.png'),
|
||||||
'red_arrow': 'marl_factory_grid/utils/plotting/action_assets/red_arrow.png',
|
'yellow_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'yellow_arrow.png'),
|
||||||
'grey_arrow': 'marl_factory_grid/utils/plotting/action_assets/grey_arrow.png',
|
'red_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'red_arrow.png'),
|
||||||
'wall': 'marl_factory_grid/environment/assets/wall.png',
|
'grey_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'grey_arrow.png'),
|
||||||
})
|
'wall': os.path.join(base_dir, 'environment', 'assets', 'wall.png'),
|
||||||
|
}
|
||||||
|
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path=assets_path)
|
||||||
|
|
||||||
directions = ['north', 'east', 'south', 'west']
|
directions = ['north', 'east', 'south', 'west']
|
||||||
wall_positions = swap_coordinates(factory.map.walls)
|
wall_positions = swap_coordinates(factory.map.walls)
|
||||||
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
|
|
||||||
action_entities = list(wall_entities)
|
|
||||||
|
|
||||||
dummy_action_map = load_action_map("example_action_map.txt")
|
for agent_index, agent in enumerate(agents):
|
||||||
for agent in agents:
|
if hasattr(agent, 'action_probabilities'):
|
||||||
# if hasattr(agent, 'action_probability_map'):
|
action_probabilities = unpack_action_probabilities(agent.action_probabilities)
|
||||||
# for y in range(len(agent.action_probability_map)):
|
for action_map_index, probabilities_map in enumerate(action_probabilities[agent_index]):
|
||||||
for y in range(len(dummy_action_map)):
|
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
|
||||||
# for x in range(len(agent.action_probability_map[y])):
|
action_entities = list(wall_entities)
|
||||||
for x in range(len(dummy_action_map[y])):
|
for position, probabilities in probabilities_map.items():
|
||||||
position = (x, y)
|
if position not in wall_positions:
|
||||||
if position not in wall_positions:
|
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
||||||
# action_probabilities = agent.action_probability_map[y][x]
|
sorted_indices = sorted(range(len(probabilities)), key=lambda i: -probabilities[i])
|
||||||
action_probabilities = dummy_action_map[y][x]
|
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
|
||||||
if sum(action_probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
|
||||||
# Sort actions by probability and assign colors
|
|
||||||
sorted_indices = sorted(range(len(action_probabilities)),
|
|
||||||
key=lambda i: -action_probabilities[i])
|
|
||||||
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
|
|
||||||
|
|
||||||
for rank, direction_index in enumerate(sorted_indices):
|
for rank, direction_index in enumerate(sorted_indices):
|
||||||
action = directions[direction_index]
|
action = directions[direction_index]
|
||||||
probability = action_probabilities[direction_index]
|
probability = probabilities[direction_index]
|
||||||
arrow_color = colors[rank]
|
arrow_color = colors[rank]
|
||||||
if probability > 0:
|
if probability > 0:
|
||||||
action_entity = RenderEntity(
|
action_entity = RenderEntity(
|
||||||
name=arrow_color,
|
name=arrow_color,
|
||||||
pos=position,
|
pos=position,
|
||||||
probability=probability,
|
probability=probability,
|
||||||
rotation=direction_index * 90
|
rotation=direction_index * 90
|
||||||
)
|
)
|
||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
|
|
||||||
renderer.render_multi_action_icons(action_entities)
|
renderer.render_multi_action_icons(action_entities)
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_action_probabilities(action_probabilities):
|
||||||
|
unpacked = {}
|
||||||
|
for agent_index, maps in action_probabilities.items():
|
||||||
|
unpacked[agent_index] = []
|
||||||
|
for map_index, probability_map in enumerate(maps):
|
||||||
|
single_map = {}
|
||||||
|
for y in range(len(probability_map)):
|
||||||
|
for x in range(len(probability_map[y])):
|
||||||
|
position = (x, y)
|
||||||
|
probabilities = probability_map[y][x]
|
||||||
|
single_map[position] = probabilities
|
||||||
|
unpacked[agent_index].append(single_map)
|
||||||
|
return unpacked
|
||||||
|
|
||||||
|
|
||||||
def load_action_map(file_path):
|
def load_action_map(file_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user