Experiments look good
This commit is contained in:
@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
||||
from plotting.plotting import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
@ -37,7 +76,10 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
if run_path.is_dir():
|
||||
prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import seaborn as sns
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
@ -21,7 +22,14 @@ PALETTE = 10 * (
|
||||
def plot(filepath, ext='png'):
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
ax = plt.gca()
|
||||
legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
|
||||
|
||||
if legends:
|
||||
figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
|
||||
else:
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
|
||||
plt.show()
|
||||
plt.clf()
|
||||
|
||||
@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order):
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
plt.tight_layout()
|
||||
return lineplot
|
||||
@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order):
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
fig = plt.figure(figsize=(10, 11))
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
|
||||
plt.tight_layout()
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
|
Reference in New Issue
Block a user