mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-08 02:21:36 +02:00
Added various plotting methods + Fixed action maps plotting
This commit is contained in:
@ -9,6 +9,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
import scipy.stats as stats
|
||||
|
||||
from marl_factory_grid.algorithms.rl.utils import _as_torch
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
@ -106,6 +107,11 @@ def plot_action_maps(factory, agents, result_path):
|
||||
tuples = ast.literal_eval(factory.conf['Entities']['DirtPiles']['coords_or_quantity'])
|
||||
for t in tuples:
|
||||
all_target_dirts.append(t)
|
||||
|
||||
if isinstance(all_target_dirts[0], int):
|
||||
temp = all_target_dirts
|
||||
all_target_dirts = [tuple(temp)]
|
||||
|
||||
assigned_spawn_positions = []
|
||||
for j in range(len(spawnpoints) // len(all_target_dirts)):
|
||||
assigned_spawn_positions.append(spawnpoints[j * len(all_target_dirts) + all_target_dirts.index(target_dirt_pos)])
|
||||
@ -204,73 +210,239 @@ direction_mapping = {
|
||||
}
|
||||
|
||||
|
||||
def plot_reward_development(reward_development, results_path):
|
||||
smoothed_data = np.convolve(reward_development, np.ones(10) / 10, mode='valid')
|
||||
def plot_return_development(return_development, results_path, discounted=False):
|
||||
smoothed_data = np.convolve(return_development, np.ones(10) / 10, mode='valid')
|
||||
plt.plot(smoothed_data)
|
||||
plt.ylim([-10, max(smoothed_data) + 20])
|
||||
plt.title('Smoothed Reward Development')
|
||||
plt.title('Smoothed Return Development' if not discounted else 'Smoothed Discounted Return Development')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Reward')
|
||||
plt.savefig(f"{results_path}/smoothed_reward_development.png")
|
||||
plt.ylabel('Return' if not discounted else "Discounted Return")
|
||||
plt.savefig(f"{results_path}/smoothed_return_development.png"
|
||||
if not discounted else f"{results_path}/smoothed_discounted_return_development.png")
|
||||
plt.show()
|
||||
|
||||
def plot_return_development_change(return_change_development, results_path):
|
||||
plt.plot(return_change_development)
|
||||
plt.title('Return Change Development')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Delta Return')
|
||||
plt.savefig(f"{results_path}/return_change_development.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_collected_coins_per_step():
|
||||
def mean_confidence_interval(data, confidence=0.95):
|
||||
a = np.array(data)
|
||||
n = np.sum(~np.isnan(a), axis=0)
|
||||
mean = np.nanmean(a, axis=0)
|
||||
se = np.nanstd(a, axis=0) / np.sqrt(n)
|
||||
h = se * 1.96 # For 95% confidence interval
|
||||
return mean, mean - h, mean + h
|
||||
|
||||
def load_metrics(file_path, key):
|
||||
with open(file_path, "rb") as pickle_file:
|
||||
metrics = pickle.load(pickle_file)
|
||||
return metrics[key][0]
|
||||
|
||||
def pad_runs(runs):
|
||||
max_length = max(len(run) for run in runs)
|
||||
padded_runs = [np.pad(np.array(run, dtype=float), (0, max_length - len(run)), constant_values=np.nan) for run in runs]
|
||||
return padded_runs
|
||||
|
||||
def get_reached_flags_metrics(runs):
|
||||
# Find the step where flag 1 and flag 2 are reached
|
||||
flag1_steps = []
|
||||
flag2_steps = []
|
||||
|
||||
for run in runs:
|
||||
if 1 in run:
|
||||
flag1_steps.append(run.index(1))
|
||||
if 2 in run:
|
||||
flag2_steps.append(run.index(2))
|
||||
|
||||
print(flag1_steps)
|
||||
print(flag2_steps)
|
||||
|
||||
# Calculate the mean steps and confidence intervals
|
||||
mean_flag1_steps = np.mean(flag1_steps)
|
||||
mean_flag2_steps = np.mean(flag2_steps)
|
||||
|
||||
std_flag1_steps = np.std(flag1_steps, ddof=1)
|
||||
std_flag2_steps = np.std(flag2_steps, ddof=1)
|
||||
|
||||
n_flag1 = len(flag1_steps)
|
||||
n_flag2 = len(flag2_steps)
|
||||
|
||||
confidence_level = 0.95
|
||||
t_critical_flag1 = stats.t.ppf((1 + confidence_level) / 2, n_flag1 - 1)
|
||||
t_critical_flag2 = stats.t.ppf((1 + confidence_level) / 2, n_flag2 - 1)
|
||||
|
||||
margin_of_error_flag1 = t_critical_flag1 * (std_flag1_steps / np.sqrt(n_flag1))
|
||||
margin_of_error_flag2 = t_critical_flag2 * (std_flag2_steps / np.sqrt(n_flag2))
|
||||
|
||||
# Mean steps including baseline
|
||||
mean_steps = [0, mean_flag1_steps, mean_flag2_steps]
|
||||
flags_reached = [0, 1, 2]
|
||||
error_bars = [0, margin_of_error_flag1, margin_of_error_flag2]
|
||||
return mean_steps, flags_reached, error_bars
|
||||
|
||||
def plot_collected_coins_per_step(rl_runs_names, tsp_runs_names, results_path):
|
||||
# Observed behaviour for multi-agent setting consisting of run0 and run0
|
||||
cleaned_dirt_per_step_emergent = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5]
|
||||
cleaned_dirt_per_step = [0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 5] # RL and TSP
|
||||
|
||||
plt.step(range(1, len(cleaned_dirt_per_step) + 1), cleaned_dirt_per_step, color='green', linewidth=3, label='Prevented (RL)')
|
||||
# Load RL and TSP data from multiple runs
|
||||
rl_runs = [load_metrics(results_path + f"/{rl_run}/metrics", "cleaned_dirt_piles_per_step") for rl_run in rl_runs_names]
|
||||
|
||||
tsp_runs = [load_metrics(results_path + f"/{tsp_run}/metrics", "cleaned_dirt_piles_per_step") for tsp_run in tsp_runs_names]
|
||||
|
||||
# Pad runs to handle heterogeneous lengths
|
||||
rl_runs = pad_runs(rl_runs)
|
||||
tsp_runs = pad_runs(tsp_runs)
|
||||
|
||||
# Calculate mean and confidence intervals
|
||||
mean_rl, lower_rl, upper_rl = mean_confidence_interval(rl_runs)
|
||||
mean_tsp, lower_tsp, upper_tsp = mean_confidence_interval(tsp_runs)
|
||||
|
||||
# Plot the mean and confidence intervals
|
||||
plt.fill_between(range(1, len(mean_rl) + 1), lower_rl, upper_rl, color='green', alpha=0.2)
|
||||
plt.step(range(1, len(mean_rl) + 1), mean_rl, color='green', linewidth=3, label='Prevented (RL)')
|
||||
|
||||
plt.fill_between(range(1, len(mean_tsp) + 1), lower_tsp, upper_tsp, color='darkorange', alpha=0.2)
|
||||
plt.step(range(1, len(mean_tsp) + 1), mean_tsp, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
|
||||
|
||||
plt.step(range(1, len(cleaned_dirt_per_step_emergent) + 1), cleaned_dirt_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
|
||||
plt.step(range(1, len(cleaned_dirt_per_step) + 1), cleaned_dirt_per_step, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
|
||||
|
||||
plt.xlabel("Environment step", fontsize=20)
|
||||
plt.ylabel("Collected Coins", fontsize=20)
|
||||
yint = range(min(cleaned_dirt_per_step), max(cleaned_dirt_per_step) + 1)
|
||||
plt.yticks(yint, fontsize=17)
|
||||
plt.xticks(range(1, len(cleaned_dirt_per_step_emergent) + 1), fontsize=17)
|
||||
plt.yticks(fontsize=17)
|
||||
|
||||
frame1 = plt.gca()
|
||||
# Only display every 5th tick label
|
||||
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
|
||||
if (idx + 1) % 5 != 0:
|
||||
xlabel_i.set_visible(False)
|
||||
xlabel_i.set_fontsize(0.0)
|
||||
# Change order of labels in legend
|
||||
|
||||
handles, labels = frame1.get_legend_handles_labels()
|
||||
order = [0, 2, 1]
|
||||
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
|
||||
|
||||
fig = plt.gcf()
|
||||
fig.set_size_inches(8, 7)
|
||||
plt.savefig("../study_out/number_of_collected_coins.pdf")
|
||||
plt.savefig(f"{results_path}/number_of_collected_coins.pdf")
|
||||
plt.show()
|
||||
|
||||
def plot_reached_flags_per_step(rl_runs_names, tsp_runs_names, results_path):
|
||||
reached_flags_per_step_emergent = [0] * 32 # Adjust based on your data length
|
||||
|
||||
# Load RL and TSP data from multiple runs
|
||||
rl_runs = [load_metrics(results_path + f"/{rl_run}/metrics", "cleaned_dirt_piles_per_step") for rl_run in rl_runs_names]
|
||||
rl_runs = [[pile - 1 for pile in run] for run in rl_runs] # Subtract the auxiliary pile
|
||||
|
||||
tsp_runs = [load_metrics(results_path + f"/{tsp_run}/metrics", "reached_flags") for tsp_run in tsp_runs_names]
|
||||
|
||||
# Pad runs to handle heterogeneous lengths
|
||||
rl_runs = pad_runs(rl_runs)
|
||||
tsp_runs = pad_runs(tsp_runs)
|
||||
|
||||
# Calculate mean and confidence intervals
|
||||
mean_rl, lower_rl, upper_rl = mean_confidence_interval(rl_runs)
|
||||
mean_tsp, lower_tsp, upper_tsp = mean_confidence_interval(tsp_runs)
|
||||
|
||||
# Plot the mean and confidence intervals
|
||||
plt.fill_between(range(1, len(mean_rl) + 1), lower_rl, upper_rl, color='green', alpha=0.2)
|
||||
plt.step(range(1, len(mean_rl) + 1), mean_rl, color='green', linewidth=3, label='Prevented (RL)')
|
||||
|
||||
plt.fill_between(range(1, len(mean_tsp) + 1), lower_tsp, upper_tsp, color='darkorange', alpha=0.2)
|
||||
plt.step(range(1, len(mean_tsp) + 1), mean_tsp, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
|
||||
|
||||
plt.step(range(1, len(reached_flags_per_step_emergent) + 1), reached_flags_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
|
||||
|
||||
plt.xlabel("Environment step", fontsize=20)
|
||||
plt.ylabel("Reached Flags", fontsize=20)
|
||||
plt.xticks(range(1, len(reached_flags_per_step_emergent) + 1), fontsize=17)
|
||||
plt.yticks(fontsize=17)
|
||||
|
||||
frame1 = plt.gca()
|
||||
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
|
||||
if (idx + 1) % 5 != 0:
|
||||
xlabel_i.set_visible(False)
|
||||
xlabel_i.set_fontsize(0.0)
|
||||
|
||||
handles, labels = frame1.get_legend_handles_labels()
|
||||
order = [0, 2, 1]
|
||||
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
|
||||
|
||||
fig = plt.gcf()
|
||||
fig.set_size_inches(8, 7)
|
||||
plt.savefig(f"{results_path}/number_of_reached_flags.pdf")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_reached_flags_per_step():
|
||||
# Observed behaviour for multi-agent setting consisting of runs 1 + 2
|
||||
reached_flags_per_step_emergent = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
reached_flags_per_step_RL = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
|
||||
reached_flags_per_step_TSP = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]
|
||||
def plot_performance_distribution_on_coin_quadrant(dirt_quadrant, results_path, grid=False):
|
||||
plt.rcParams["figure.autolayout"] = True
|
||||
plt.rcParams["axes.edgecolor"] = "black"
|
||||
plt.rcParams["axes.linewidth"] = 5.0
|
||||
fig = plt.figure(figsize=(18, 13))
|
||||
|
||||
plt.step(range(1, len(reached_flags_per_step_RL) + 1), reached_flags_per_step_RL, color='green', linewidth=3, label='Prevented (RL)')
|
||||
plt.step(range(1, len(reached_flags_per_step_emergent) + 1), reached_flags_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
|
||||
plt.step(range(1, len(reached_flags_per_step_TSP) + 1), reached_flags_per_step_TSP, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
|
||||
plt.xlabel("Environment step", fontsize=20)
|
||||
plt.ylabel("Reached Flags", fontsize=20)
|
||||
yint = range(min(reached_flags_per_step_RL), max(reached_flags_per_step_RL) + 1)
|
||||
plt.yticks(yint, fontsize=17)
|
||||
plt.xticks(range(1, len(reached_flags_per_step_emergent) + 1), fontsize=17)
|
||||
frame1 = plt.gca()
|
||||
# Only display every 5th tick label
|
||||
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
|
||||
if (idx + 1) % 5 != 0:
|
||||
xlabel_i.set_visible(False)
|
||||
xlabel_i.set_fontsize(0.0)
|
||||
# Change order of labels in legend
|
||||
handles, labels = frame1.get_legend_handles_labels()
|
||||
order = [0, 2, 1]
|
||||
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
|
||||
fig = plt.gcf()
|
||||
fig.set_size_inches(8, 7)
|
||||
plt.savefig("../study_out/number_of_reached_flags.pdf")
|
||||
rl_color = '#5D3A9B'
|
||||
tsp_color = '#E66100'
|
||||
|
||||
# Boxplot
|
||||
boxprops = dict(linestyle='-', linewidth=4)
|
||||
whiskerprops = dict(linestyle='-', linewidth=4)
|
||||
capprops = dict(linestyle='-', linewidth=4)
|
||||
flierprops = dict(marker='o', markersize=14, markeredgewidth=4,
|
||||
linestyle='none')
|
||||
medianprops = dict(linestyle='-', linewidth=4, color='#40B0A6')
|
||||
meanpointprops = dict(marker='D', markeredgecolor='black',
|
||||
markerfacecolor='firebrick')
|
||||
meanlineprops = dict(linestyle='-.', linewidth=4, color='purple')
|
||||
|
||||
bp = plt.boxplot([dirt_quadrant["RL_emergence"], dirt_quadrant["RL_prevented"], dirt_quadrant["TSP_emergence"],
|
||||
dirt_quadrant["TSP_prevented"]], patch_artist=True, widths=0.6, flierprops=flierprops,
|
||||
boxprops=boxprops, medianprops=medianprops, meanprops=meanlineprops,
|
||||
whiskerprops=whiskerprops, capprops=capprops,
|
||||
meanline=True, showmeans=False, positions=[1, 2.5, 4, 5.5])
|
||||
|
||||
colors = [rl_color, rl_color, tsp_color, tsp_color]
|
||||
|
||||
for bplot, color in zip([bp], [colors, colors]):
|
||||
for patch, color in zip(bplot['boxes'], color):
|
||||
patch.set_facecolor(color)
|
||||
|
||||
plt.tick_params(width=5, length=10)
|
||||
plt.xticks([1, 2.5, 4, 5.5], labels=['Emergent \n (RL)', 'Prevented \n (RL)', 'Emergent \n (TSP)', 'Prevented \n (TSP)'], fontsize=50)
|
||||
plt.yticks(fontsize=50)
|
||||
plt.ylabel('No. environment steps', fontsize=50)
|
||||
plt.xlabel("Agent Types", fontsize=50)
|
||||
plt.grid(grid)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{results_path}/number_of_collected_coins_distribution{'_grid' if grid else ''}.pdf")
|
||||
plt.show()
|
||||
|
||||
def plot_reached_flags_per_step_with_error(mean_steps_RL_prevented, error_bars_RL_prevented,
|
||||
mean_steps_TSP_prevented, error_bars_TSP_prevented, flags_reached,
|
||||
results_path, grid=False):
|
||||
plt.rcParams["figure.autolayout"] = True
|
||||
plt.rcParams["axes.edgecolor"] = "black"
|
||||
plt.rcParams["axes.linewidth"] = 5.0
|
||||
fig = plt.figure(figsize=(18, 13))
|
||||
|
||||
# Line plot with error bars
|
||||
plt.plot(range(30), [0 for _ in range(30)], color='gray', linestyle='--', linewidth=7,
|
||||
label='Emergent')
|
||||
plt.errorbar(mean_steps_RL_prevented, flags_reached, xerr=error_bars_RL_prevented, fmt='-o', ecolor='r', capsize=10, capthick=5,
|
||||
markersize=20, label='Prevented (RL) + CI', color='#5D3A9B', linewidth=7)
|
||||
plt.errorbar(mean_steps_TSP_prevented, flags_reached, xerr=error_bars_TSP_prevented, fmt='-o', ecolor='r', capsize=10, capthick=5,
|
||||
markersize=20, label='Prevented (TSP) + CI', color='#E66100', linewidth=7)
|
||||
plt.tick_params(width=5, length=10)
|
||||
plt.xticks(fontsize=50)
|
||||
plt.yticks(flags_reached, fontsize=50)
|
||||
plt.xlabel("Avg. environment step", fontsize=50)
|
||||
plt.ylabel('Reached flags', fontsize=50)
|
||||
plt.legend(fontsize=45, loc='best', bbox_to_anchor=(0.38, 0.38))
|
||||
plt.grid(grid)
|
||||
plt.savefig(f"{results_path}/number_of_reached_flags{'_grid' if grid else ''}.pdf")
|
||||
plt.show()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user