Added various plotting methods + Fixed action maps plotting

This commit is contained in:
Julian Schönberger
2024-08-09 16:35:08 +02:00
parent fcd2eaf773
commit 973f3e9fc8

View File

@ -9,6 +9,7 @@ import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
import scipy.stats as stats
from marl_factory_grid.algorithms.rl.utils import _as_torch
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
@ -106,6 +107,11 @@ def plot_action_maps(factory, agents, result_path):
tuples = ast.literal_eval(factory.conf['Entities']['DirtPiles']['coords_or_quantity'])
for t in tuples:
all_target_dirts.append(t)
if isinstance(all_target_dirts[0], int):
temp = all_target_dirts
all_target_dirts = [tuple(temp)]
assigned_spawn_positions = []
for j in range(len(spawnpoints) // len(all_target_dirts)):
assigned_spawn_positions.append(spawnpoints[j * len(all_target_dirts) + all_target_dirts.index(target_dirt_pos)])
@ -204,73 +210,239 @@ direction_mapping = {
}
def plot_reward_development(reward_development, results_path):
smoothed_data = np.convolve(reward_development, np.ones(10) / 10, mode='valid')
def plot_return_development(return_development, results_path, discounted=False):
smoothed_data = np.convolve(return_development, np.ones(10) / 10, mode='valid')
plt.plot(smoothed_data)
plt.ylim([-10, max(smoothed_data) + 20])
plt.title('Smoothed Reward Development')
plt.title('Smoothed Return Development' if not discounted else 'Smoothed Discounted Return Development')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.savefig(f"{results_path}/smoothed_reward_development.png")
plt.ylabel('Return' if not discounted else "Discounted Return")
plt.savefig(f"{results_path}/smoothed_return_development.png"
if not discounted else f"{results_path}/smoothed_discounted_return_development.png")
plt.show()
def plot_return_development_change(return_change_development, results_path):
plt.plot(return_change_development)
plt.title('Return Change Development')
plt.xlabel('Episode')
plt.ylabel('Delta Return')
plt.savefig(f"{results_path}/return_change_development.png")
plt.show()
def plot_collected_coins_per_step():
def mean_confidence_interval(data, confidence=0.95):
a = np.array(data)
n = np.sum(~np.isnan(a), axis=0)
mean = np.nanmean(a, axis=0)
se = np.nanstd(a, axis=0) / np.sqrt(n)
h = se * 1.96 # For 95% confidence interval
return mean, mean - h, mean + h
def load_metrics(file_path, key):
with open(file_path, "rb") as pickle_file:
metrics = pickle.load(pickle_file)
return metrics[key][0]
def pad_runs(runs):
max_length = max(len(run) for run in runs)
padded_runs = [np.pad(np.array(run, dtype=float), (0, max_length - len(run)), constant_values=np.nan) for run in runs]
return padded_runs
def get_reached_flags_metrics(runs):
# Find the step where flag 1 and flag 2 are reached
flag1_steps = []
flag2_steps = []
for run in runs:
if 1 in run:
flag1_steps.append(run.index(1))
if 2 in run:
flag2_steps.append(run.index(2))
print(flag1_steps)
print(flag2_steps)
# Calculate the mean steps and confidence intervals
mean_flag1_steps = np.mean(flag1_steps)
mean_flag2_steps = np.mean(flag2_steps)
std_flag1_steps = np.std(flag1_steps, ddof=1)
std_flag2_steps = np.std(flag2_steps, ddof=1)
n_flag1 = len(flag1_steps)
n_flag2 = len(flag2_steps)
confidence_level = 0.95
t_critical_flag1 = stats.t.ppf((1 + confidence_level) / 2, n_flag1 - 1)
t_critical_flag2 = stats.t.ppf((1 + confidence_level) / 2, n_flag2 - 1)
margin_of_error_flag1 = t_critical_flag1 * (std_flag1_steps / np.sqrt(n_flag1))
margin_of_error_flag2 = t_critical_flag2 * (std_flag2_steps / np.sqrt(n_flag2))
# Mean steps including baseline
mean_steps = [0, mean_flag1_steps, mean_flag2_steps]
flags_reached = [0, 1, 2]
error_bars = [0, margin_of_error_flag1, margin_of_error_flag2]
return mean_steps, flags_reached, error_bars
def plot_collected_coins_per_step(rl_runs_names, tsp_runs_names, results_path):
# Observed behaviour for multi-agent setting consisting of run0 and run0
cleaned_dirt_per_step_emergent = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5]
cleaned_dirt_per_step = [0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 5] # RL and TSP
plt.step(range(1, len(cleaned_dirt_per_step) + 1), cleaned_dirt_per_step, color='green', linewidth=3, label='Prevented (RL)')
# Load RL and TSP data from multiple runs
rl_runs = [load_metrics(results_path + f"/{rl_run}/metrics", "cleaned_dirt_piles_per_step") for rl_run in rl_runs_names]
tsp_runs = [load_metrics(results_path + f"/{tsp_run}/metrics", "cleaned_dirt_piles_per_step") for tsp_run in tsp_runs_names]
# Pad runs to handle heterogeneous lengths
rl_runs = pad_runs(rl_runs)
tsp_runs = pad_runs(tsp_runs)
# Calculate mean and confidence intervals
mean_rl, lower_rl, upper_rl = mean_confidence_interval(rl_runs)
mean_tsp, lower_tsp, upper_tsp = mean_confidence_interval(tsp_runs)
# Plot the mean and confidence intervals
plt.fill_between(range(1, len(mean_rl) + 1), lower_rl, upper_rl, color='green', alpha=0.2)
plt.step(range(1, len(mean_rl) + 1), mean_rl, color='green', linewidth=3, label='Prevented (RL)')
plt.fill_between(range(1, len(mean_tsp) + 1), lower_tsp, upper_tsp, color='darkorange', alpha=0.2)
plt.step(range(1, len(mean_tsp) + 1), mean_tsp, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
plt.step(range(1, len(cleaned_dirt_per_step_emergent) + 1), cleaned_dirt_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
plt.step(range(1, len(cleaned_dirt_per_step) + 1), cleaned_dirt_per_step, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
plt.xlabel("Environment step", fontsize=20)
plt.ylabel("Collected Coins", fontsize=20)
yint = range(min(cleaned_dirt_per_step), max(cleaned_dirt_per_step) + 1)
plt.yticks(yint, fontsize=17)
plt.xticks(range(1, len(cleaned_dirt_per_step_emergent) + 1), fontsize=17)
plt.yticks(fontsize=17)
frame1 = plt.gca()
# Only display every 5th tick label
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
if (idx + 1) % 5 != 0:
xlabel_i.set_visible(False)
xlabel_i.set_fontsize(0.0)
# Change order of labels in legend
handles, labels = frame1.get_legend_handles_labels()
order = [0, 2, 1]
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
fig = plt.gcf()
fig.set_size_inches(8, 7)
plt.savefig("../study_out/number_of_collected_coins.pdf")
plt.savefig(f"{results_path}/number_of_collected_coins.pdf")
plt.show()
def plot_reached_flags_per_step(rl_runs_names, tsp_runs_names, results_path):
reached_flags_per_step_emergent = [0] * 32 # Adjust based on your data length
# Load RL and TSP data from multiple runs
rl_runs = [load_metrics(results_path + f"/{rl_run}/metrics", "cleaned_dirt_piles_per_step") for rl_run in rl_runs_names]
rl_runs = [[pile - 1 for pile in run] for run in rl_runs] # Subtract the auxiliary pile
tsp_runs = [load_metrics(results_path + f"/{tsp_run}/metrics", "reached_flags") for tsp_run in tsp_runs_names]
# Pad runs to handle heterogeneous lengths
rl_runs = pad_runs(rl_runs)
tsp_runs = pad_runs(tsp_runs)
# Calculate mean and confidence intervals
mean_rl, lower_rl, upper_rl = mean_confidence_interval(rl_runs)
mean_tsp, lower_tsp, upper_tsp = mean_confidence_interval(tsp_runs)
# Plot the mean and confidence intervals
plt.fill_between(range(1, len(mean_rl) + 1), lower_rl, upper_rl, color='green', alpha=0.2)
plt.step(range(1, len(mean_rl) + 1), mean_rl, color='green', linewidth=3, label='Prevented (RL)')
plt.fill_between(range(1, len(mean_tsp) + 1), lower_tsp, upper_tsp, color='darkorange', alpha=0.2)
plt.step(range(1, len(mean_tsp) + 1), mean_tsp, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
plt.step(range(1, len(reached_flags_per_step_emergent) + 1), reached_flags_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
plt.xlabel("Environment step", fontsize=20)
plt.ylabel("Reached Flags", fontsize=20)
plt.xticks(range(1, len(reached_flags_per_step_emergent) + 1), fontsize=17)
plt.yticks(fontsize=17)
frame1 = plt.gca()
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
if (idx + 1) % 5 != 0:
xlabel_i.set_visible(False)
xlabel_i.set_fontsize(0.0)
handles, labels = frame1.get_legend_handles_labels()
order = [0, 2, 1]
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
fig = plt.gcf()
fig.set_size_inches(8, 7)
plt.savefig(f"{results_path}/number_of_reached_flags.pdf")
plt.show()
def plot_reached_flags_per_step():
# Observed behaviour for multi-agent setting consisting of runs 1 + 2
reached_flags_per_step_emergent = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
reached_flags_per_step_RL = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
reached_flags_per_step_TSP = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]
def plot_performance_distribution_on_coin_quadrant(dirt_quadrant, results_path, grid=False):
plt.rcParams["figure.autolayout"] = True
plt.rcParams["axes.edgecolor"] = "black"
plt.rcParams["axes.linewidth"] = 5.0
fig = plt.figure(figsize=(18, 13))
plt.step(range(1, len(reached_flags_per_step_RL) + 1), reached_flags_per_step_RL, color='green', linewidth=3, label='Prevented (RL)')
plt.step(range(1, len(reached_flags_per_step_emergent) + 1), reached_flags_per_step_emergent, linestyle='--', color='darkred', linewidth=3, label='Emergent')
plt.step(range(1, len(reached_flags_per_step_TSP) + 1), reached_flags_per_step_TSP, linestyle='dotted', color='darkorange', linewidth=3, label='Prevented (TSP)')
plt.xlabel("Environment step", fontsize=20)
plt.ylabel("Reached Flags", fontsize=20)
yint = range(min(reached_flags_per_step_RL), max(reached_flags_per_step_RL) + 1)
plt.yticks(yint, fontsize=17)
plt.xticks(range(1, len(reached_flags_per_step_emergent) + 1), fontsize=17)
frame1 = plt.gca()
# Only display every 5th tick label
for idx, xlabel_i in enumerate(frame1.axes.get_xticklabels()):
if (idx + 1) % 5 != 0:
xlabel_i.set_visible(False)
xlabel_i.set_fontsize(0.0)
# Change order of labels in legend
handles, labels = frame1.get_legend_handles_labels()
order = [0, 2, 1]
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order], prop={'size': 20})
fig = plt.gcf()
fig.set_size_inches(8, 7)
plt.savefig("../study_out/number_of_reached_flags.pdf")
rl_color = '#5D3A9B'
tsp_color = '#E66100'
# Boxplot
boxprops = dict(linestyle='-', linewidth=4)
whiskerprops = dict(linestyle='-', linewidth=4)
capprops = dict(linestyle='-', linewidth=4)
flierprops = dict(marker='o', markersize=14, markeredgewidth=4,
linestyle='none')
medianprops = dict(linestyle='-', linewidth=4, color='#40B0A6')
meanpointprops = dict(marker='D', markeredgecolor='black',
markerfacecolor='firebrick')
meanlineprops = dict(linestyle='-.', linewidth=4, color='purple')
bp = plt.boxplot([dirt_quadrant["RL_emergence"], dirt_quadrant["RL_prevented"], dirt_quadrant["TSP_emergence"],
dirt_quadrant["TSP_prevented"]], patch_artist=True, widths=0.6, flierprops=flierprops,
boxprops=boxprops, medianprops=medianprops, meanprops=meanlineprops,
whiskerprops=whiskerprops, capprops=capprops,
meanline=True, showmeans=False, positions=[1, 2.5, 4, 5.5])
colors = [rl_color, rl_color, tsp_color, tsp_color]
for bplot, color in zip([bp], [colors, colors]):
for patch, color in zip(bplot['boxes'], color):
patch.set_facecolor(color)
plt.tick_params(width=5, length=10)
plt.xticks([1, 2.5, 4, 5.5], labels=['Emergent \n (RL)', 'Prevented \n (RL)', 'Emergent \n (TSP)', 'Prevented \n (TSP)'], fontsize=50)
plt.yticks(fontsize=50)
plt.ylabel('No. environment steps', fontsize=50)
plt.xlabel("Agent Types", fontsize=50)
plt.grid(grid)
plt.tight_layout()
plt.savefig(f"{results_path}/number_of_collected_coins_distribution{'_grid' if grid else ''}.pdf")
plt.show()
def plot_reached_flags_per_step_with_error(mean_steps_RL_prevented, error_bars_RL_prevented,
mean_steps_TSP_prevented, error_bars_TSP_prevented, flags_reached,
results_path, grid=False):
plt.rcParams["figure.autolayout"] = True
plt.rcParams["axes.edgecolor"] = "black"
plt.rcParams["axes.linewidth"] = 5.0
fig = plt.figure(figsize=(18, 13))
# Line plot with error bars
plt.plot(range(30), [0 for _ in range(30)], color='gray', linestyle='--', linewidth=7,
label='Emergent')
plt.errorbar(mean_steps_RL_prevented, flags_reached, xerr=error_bars_RL_prevented, fmt='-o', ecolor='r', capsize=10, capthick=5,
markersize=20, label='Prevented (RL) + CI', color='#5D3A9B', linewidth=7)
plt.errorbar(mean_steps_TSP_prevented, flags_reached, xerr=error_bars_TSP_prevented, fmt='-o', ecolor='r', capsize=10, capthick=5,
markersize=20, label='Prevented (TSP) + CI', color='#E66100', linewidth=7)
plt.tick_params(width=5, length=10)
plt.xticks(fontsize=50)
plt.yticks(flags_reached, fontsize=50)
plt.xlabel("Avg. environment step", fontsize=50)
plt.ylabel('Reached flags', fontsize=50)
plt.legend(fontsize=45, loc='best', bbox_to_anchor=(0.38, 0.38))
plt.grid(grid)
plt.savefig(f"{results_path}/number_of_reached_flags{'_grid' if grid else ''}.pdf")
plt.show()