From efedce579e255d0fda109cc1f9ce481963ae7b12 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Fri, 28 May 2021 18:05:53 +0200 Subject: [PATCH] better plotting --- environments/logging/plotting.py | 1 + main.py | 24 +++++++++--------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index a9726d7..0592091 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -54,3 +54,4 @@ def prepare_plot(filepath, results_df, ext='png', tag=''): except (FileNotFoundError, RuntimeError): tex_fonts['text.usetex'] = False plot(filepath, ext=ext, tag=tag, **tex_fonts) + plt.show() diff --git a/main.py b/main.py index d71d26c..3199496 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,6 @@ from os import PathLike from pathlib import Path import time import pandas as pd -from natsort import natsorted from stable_baselines3.common.callbacks import CallbackList @@ -44,20 +43,16 @@ def combine_runs(run_path: Union[str, PathLike]): df_list.append(df) df = pd.concat(df_list, ignore_index=True) df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) + columns = [col for col in df.columns if col not in ['Episode', 'Run', 'train_step', 'step']] - df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'sum' if col in ['dirt_amount', - 'dirty_tiles'] else 'sum' - for col in df.columns if - col not in ['Episode', 'Run', 'train_step'] - }) + df_group = df.groupby(['Episode', 'Run']).aggregate( + {col: 'mean' if col in ['dirt_amount', 'dirty_tiles'] else 'sum' for col in columns} + ) - non_overlapp_window = df_group.groupby(['Run', (df_group.index.get_level_values('Episode') // 50)]).mean() + non_overlapp_window = df_group.groupby(['Run', (df_group.index.get_level_values('Episode') // 20)]).mean() df_melted = non_overlapp_window.reset_index().melt(id_vars=['Episode', 'Run'], - value_vars=['agent_0_vs_level', 'dirt_amount', - 'dirty_tiles', 'step_reward', - 'failed_cleanup_attempt', - 'dirt_cleaned'], var_name="Measurement", + value_vars=columns, var_name="Measurement", value_name="Score") prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) @@ -66,11 +61,10 @@ def combine_runs(run_path: Union[str, PathLike]): if __name__ == '__main__': - # combine_runs('debug_out/PPO_1622128912') - # exit() - - from stable_baselines3 import DQN, PPO + combine_runs('debug_out/PPO_1622120377') + exit() + from stable_baselines3 import PPO # DQN dirt_props = DirtProperties() time_stamp = int(time.time())