From 9ed83e676511d378b34537df88b70c7a7a0a290a Mon Sep 17 00:00:00 2001 From: romue Date: Wed, 2 Jun 2021 14:01:22 +0200 Subject: [PATCH] use same colors for plots --- environments/logging/plotting.py | 7 +++++-- main.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index c99c191..6cd43ea 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -28,13 +28,16 @@ def plot(filepath, ext='png'): def prepare_plot(filepath, results_df, ext='png'): results_df.Measurement = results_df.Measurement.str.replace('_', '-') + hue_order = sorted(list(results_df.Measurement.unique())) try: sns.set(rc={'text.usetex': True}, style='whitegrid') - sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci=95, palette=PALETTE) + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', + ci=95, palette=PALETTE, hue_order=hue_order) plot(filepath, ext=ext) # plot raises errors not lineplot! except (FileNotFoundError, RuntimeError): print('Struggling to plot Figure using LaTeX - going back to normal.') plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') - sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci=95, palette=PALETTE) + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', + ci=95, palette=PALETTE, hue_order=hue_order) plot(filepath, ext=ext) diff --git a/main.py b/main.py index 00fcb29..4b61a3d 100644 --- a/main.py +++ b/main.py @@ -55,6 +55,8 @@ if __name__ == '__main__': time_stamp = int(time.time()) out_path = None + combine_runs(Path(__file__).parent / 'debug_out'/ 'A2C_1622571986') + exit() for modeL_type in [A2C, PPO, DQN]: for seed in range(5):