diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index b1c8732..c99c191 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -30,11 +30,11 @@ def prepare_plot(filepath, results_df, ext='png'): results_df.Measurement = results_df.Measurement.str.replace('_', '-') try: sns.set(rc={'text.usetex': True}, style='whitegrid') - sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd', palette=PALETTE) + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci=95, palette=PALETTE) plot(filepath, ext=ext) # plot raises errors not lineplot! except (FileNotFoundError, RuntimeError): print('Struggling to plot Figure using LaTeX - going back to normal.') plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') - sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd', palette=PALETTE) + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci=95, palette=PALETTE) plot(filepath, ext=ext) diff --git a/main.py b/main.py index 997b866..00fcb29 100644 --- a/main.py +++ b/main.py @@ -32,13 +32,17 @@ def combine_runs(run_path: Union[str, PathLike]): df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] - print(df.head()) + roll_n = 30 + skip_n = 20 - non_overlapp_window = df.groupby(['Run', df['Episode'] // 20]).mean() + non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean() df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'], value_vars=columns, var_name="Measurement", value_name="Score") + df_melted = df_melted[df_melted['Episode'] % skip_n == 0] + #df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version + prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) print('Plotting done.') @@ -51,8 +55,6 @@ if __name__ == '__main__': time_stamp = int(time.time()) out_path = None - combine_runs(Path('/Users/romue/PycharmProjects/EDYS/debug_out/A2C_1622571986')) - exit() for modeL_type in [A2C, PPO, DQN]: for seed in range(5):