recoder adaption
This commit is contained in:
0
plotting/__init__.py
Normal file
0
plotting/__init__.py
Normal file
155
plotting/compare_runs.py
Normal file
155
plotting/compare_runs.py
Normal file
@ -0,0 +1,155 @@
|
||||
import pickle
|
||||
import re
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
||||
from plotting.plotting import prepare_plot
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike]):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]]):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
for path in run_path.iterdir():
|
||||
if path.is_dir() and str(run_identifier) in path.name:
|
||||
for run, monitor_file in enumerate(path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
monitor_df['model'] = next((x for x in path.name.split('_') if x in MODEL_MAP.keys()))
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
||||
columns = [col for col in df.columns if col in parameter]
|
||||
|
||||
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||
df = df[df['Episode'] < last_episode_to_report]
|
||||
|
||||
roll_n = 40
|
||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 80:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
style = 'Measurement' if len(columns) > 1 else None
|
||||
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
|
||||
param_names: Union[List[str], None] = None, str_to_ignore=''):
|
||||
run_root_path = Path(run_root_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
for monitor_idx, monitor_file in enumerate(run_root_path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
parameters = [x.name for x in monitor_file.parents if x.parent not in run_root_path.parents]
|
||||
if str_to_ignore:
|
||||
parameters = [re.sub(f'_*({str_to_ignore})', '', param) for param in parameters]
|
||||
|
||||
if monitor_idx == 0:
|
||||
if param_names is not None:
|
||||
if len(param_names) < len(parameters):
|
||||
# FIXME: Missing Seed Detection, see below @111
|
||||
param_names = [next(param_names) if param not in MODEL_MAP.keys() else 'Model' for param in parameters]
|
||||
elif len(param_names) == len(parameters):
|
||||
pass
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
param_names = []
|
||||
for param_idx, param in enumerate(parameters):
|
||||
dtype = None
|
||||
if param in MODEL_MAP.keys():
|
||||
param_name = 'Model'
|
||||
elif '_' in param:
|
||||
param_split = param.split('_')
|
||||
if len(param_split) == 2 and any(split in MODEL_MAP.keys() for split in param_split):
|
||||
# Extract the seed
|
||||
param = int(next(x for x in param_split if x not in MODEL_MAP))
|
||||
param_name = 'Seed'
|
||||
dtype = int
|
||||
else:
|
||||
param_name = f'param_{param_idx}'
|
||||
else:
|
||||
param_name = f'param_{param_idx}'
|
||||
dtype = dtype if dtype is not None else str
|
||||
monitor_df[param_name] = str(param)
|
||||
monitor_df[param_name] = monitor_df[param_name].astype(dtype)
|
||||
if monitor_idx == 0:
|
||||
param_names.append(param_name)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
|
||||
for param_name in param_names:
|
||||
df[param_name] = df[param_name].astype(str)
|
||||
columns = [col for col in df.columns if col in parameter]
|
||||
|
||||
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||
df = df[df['Episode'] < last_episode_to_report]
|
||||
|
||||
if df['Episode'].max() > 80:
|
||||
skip_n = round(df['Episode'].max() * 0.02)
|
||||
df = df[df['Episode'] % skip_n == 0]
|
||||
combinations = [x for x in param_names if x not in ['Model', 'Seed']]
|
||||
df['Parameter Combination'] = df[combinations].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
|
||||
df.drop(columns=combinations, inplace=True)
|
||||
|
||||
# non_overlapp_window = df.groupby(param_names).sum()
|
||||
|
||||
df_melted = df.reset_index().melt(id_vars=['Parameter Combination', 'Episode'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
style = 'Measurement' if len(columns) > 1 else None
|
||||
prepare_plot(run_root_path / f'compare_{parameter}.png', df_melted, hue='Parameter Combination', style=style)
|
||||
print('Plotting done.')
|
46
plotting/plotting.py
Normal file
46
plotting/plotting.py
Normal file
@ -0,0 +1,46 @@
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
|
||||
def plot(filepath, ext='png'):
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
plt.show()
|
||||
plt.clf()
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None):
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
hue_order = sorted(list(df[hue].unique()))
|
||||
try:
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext)
|
Reference in New Issue
Block a user