From c798226b2695f33ef6ea6cb33962cea648bbd124 Mon Sep 17 00:00:00 2001 From: romue Date: Wed, 2 Jun 2021 10:31:49 +0200 Subject: [PATCH] fixed prepare_plot and added LaTeX support --- environments/logging/plotting.py | 34 ++++++++++---------------------- main.py | 4 ++++ 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 654f150..b1c8732 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -18,9 +18,7 @@ PALETTE = 10 * ( ) -def plot(filepath, ext='png', **kwargs): - plt.rcParams.update(kwargs) - +def plot(filepath, ext='png'): plt.tight_layout() figure = plt.gcf() figure.savefig(str(filepath), format=ext) @@ -29,26 +27,14 @@ def plot(filepath, ext='png', **kwargs): def prepare_plot(filepath, results_df, ext='png'): - - sns.set_theme(palette=PALETTE, style='whitegrid') - font_size = 16 - tex_fonts = { - # Use LaTeX to write all text - "text.usetex": False, - "font.family": "serif", - # Use 10pt font in plots, to match 10pt font in document - "axes.labelsize": font_size, - "font.size": font_size, - # Make the legend/label fonts a little smaller - "legend.fontsize": font_size - 2, - "xtick.labelsize": font_size - 2, - "ytick.labelsize": font_size - 2 - } - - sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd') - + results_df.Measurement = results_df.Measurement.str.replace('_', '-') try: - plot(filepath, ext=ext, **tex_fonts) + sns.set(rc={'text.usetex': True}, style='whitegrid') + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd', palette=PALETTE) + plot(filepath, ext=ext) # plot raises errors not lineplot! except (FileNotFoundError, RuntimeError): - tex_fonts['text.usetex'] = False - plot(filepath, ext=ext, **tex_fonts) + print('Struggling to plot Figure using LaTeX - going back to normal.') + plt.close('all') + sns.set(rc={'text.usetex': False}, style='whitegrid') + sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd', palette=PALETTE) + plot(filepath, ext=ext) diff --git a/main.py b/main.py index e364e80..997b866 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,8 @@ def combine_runs(run_path: Union[str, PathLike]): df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] + print(df.head()) + non_overlapp_window = df.groupby(['Run', df['Episode'] // 20]).mean() df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'], @@ -49,6 +51,8 @@ if __name__ == '__main__': time_stamp = int(time.time()) out_path = None + combine_runs(Path('/Users/romue/PycharmProjects/EDYS/debug_out/A2C_1622571986')) + exit() for modeL_type in [A2C, PPO, DQN]: for seed in range(5):