recorder adaptation

Steffen Illium
2021-10-04 17:53:19 +02:00
parent 4c21a0af7c
commit 696e520862
21 changed files with 665 additions and 380 deletions

View File

@@ -6,7 +6,7 @@ from typing import List, Dict
from stable_baselines3.common.callbacks import BaseCallback
from environments.helpers import IGNORED_DF_COLUMNS
from environments.logging.plotting import prepare_plot
import pandas as pd
@@ -14,85 +14,76 @@ class MonitorCallback(BaseCallback):
ext = 'png'
def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True):
def __init__(self, filepath=Path('debug_out/monitor.pick')):
super(MonitorCallback, self).__init__()
self.filepath = Path(filepath)
self._monitor_df = pd.DataFrame()
self._monitor_dicts = defaultdict(dict)
self.plotting = plotting
self.started = False
self.closed = False
def __enter__(self):
self._on_training_start()
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._on_training_end()
self.stop()
def _on_training_start(self) -> None:
if self.started:
pass
else:
self.filepath.parent.mkdir(exist_ok=True, parents=True)
self.started = True
self.start()
pass
def _on_training_end(self) -> None:
if self.closed:
pass
else:
# self.out_file.unlink(missing_ok=True)
with self.filepath.open('wb') as f:
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
if self.plotting:
print('Monitor files were dumped to disk, now plotting....')
# %% Load MonitorList from Disk
with self.filepath.open('rb') as f:
monitor_list = pickle.load(f)
df = None
for m_idx, monitor in enumerate(monitor_list):
monitor['episode'] = m_idx
if df is None:
df = pd.DataFrame(columns=monitor.columns)
for _, row in monitor.iterrows():
df.loc[df.shape[0]] = row
if df is None:  # The env exited prematurely, we catch it.
self.closed = True
return
for column in list(df.columns):
if column != 'episode':
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
# result.tail()
prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")))
print('Plotting done.')
self.closed = True
self.stop()
def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
infos = alt_infos or self.locals.get('infos', [])
if alt_dones is not None:
dones = alt_dones
elif self.locals.get('dones', None) is not None:
dones = self.locals.get('dones', None)
elif self.locals.get('done', None) is not None:
dones = self.locals.get('done', [None])
else:
dones = []
if self.started:
for env_idx, info in enumerate(self.locals.get('infos', [])):
self.read_info(env_idx, info)
for env_idx, (info, done) in enumerate(zip(infos, dones)):
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
if key not in ['terminal_observation', 'episode']
and not key.startswith('rec_')}
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
self._monitor_dicts[env_idx] = dict()
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
env_monitor_df = env_monitor_df.aggregate(
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
)
env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = self._monitor_df.append([env_monitor_df])
else:
pass
for env_idx, done in list(
enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
self.read_done(env_idx, done)
else:
pass
return True
def read_info(self, env_idx, info: dict):
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {
key: val for key, val in info.items() if
key not in ['terminal_observation', 'episode'] and not key.startswith('rec_')}
return
def read_done(self, env_idx, done):
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
self._monitor_dicts[env_idx] = dict()
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
env_monitor_df = env_monitor_df.aggregate(
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
)
env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = self._monitor_df.append([env_monitor_df])
else:
pass
return
def stop(self):
# self.out_file.unlink(missing_ok=True)
with self.filepath.open('wb') as f:
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
self.closed = True
def start(self):
if self.started:
pass
else:
self.filepath.parent.mkdir(exist_ok=True, parents=True)
self.started = True
pass
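
Taken together, the monitor changes split the old _on_training_start()/_on_training_end() bodies into reusable start()/stop() methods and route the per-step bookkeeping through read_info()/read_done(), so the callback can also be driven by hand outside of a stable-baselines3 learn() call. A minimal sketch of that manual use follows; the import path and the info keys ('step_reward', 'dirt_amount') are assumptions for illustration, not part of this commit.

# Sketch only: drive the refactored MonitorCallback without stable-baselines3.
# The module path below is assumed; adjust it to wherever MonitorCallback lives.
from environments.logging.monitor import MonitorCallback

with MonitorCallback(filepath='debug_out/monitor.pick') as monitor:
    for _ in range(3):                      # three fake episodes
        for _ in range(10):                 # ten fake steps per episode
            # read_info() drops 'terminal_observation', 'episode' and any
            # 'rec_'-prefixed keys, then buffers the remaining step stats.
            monitor.read_info(env_idx=0, info={'step_reward': 0.1, 'dirt_amount': 4})
        # read_done() collapses the buffered steps into one row per episode:
        # columns ending in 'ount' are averaged, everything else is summed.
        monitor.read_done(env_idx=0, done=True)
# Leaving the with-block calls stop(), which pickles the aggregated DataFrame
# to the configured filepath.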

View File

@@ -1,46 +0,0 @@
import seaborn as sns
from matplotlib import pyplot as plt
PALETTE = 10 * (
"#377eb8",
"#4daf4a",
"#984ea3",
"#e41a1c",
"#ff7f00",
"#a65628",
"#f781bf",
"#888888",
"#a6cee3",
"#b2df8a",
"#cab2d6",
"#fb9a99",
"#fdbf6f",
)
def plot(filepath, ext='png'):
plt.tight_layout()
figure = plt.gcf()
figure.savefig(str(filepath), format=ext)
plt.show()
plt.clf()
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None):
df = results_df.copy()
df[hue] = df[hue].str.replace('_', '-')
hue_order = sorted(list(df[hue].unique()))
try:
sns.set(rc={'text.usetex': True}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
hue_order=hue_order, hue=hue, style=style)
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
plot(filepath, ext=ext)  # plot() raises the errors, not the lineplot call!
except (FileNotFoundError, RuntimeError):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order)
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
plot(filepath, ext=ext)
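
The deleted file above held the standalone seaborn helpers plot() and prepare_plot(); the monitor hunk at the top still references prepare_plot via environments.logging.plotting, which suggests the helper moved rather than vanished. For readers unfamiliar with its input, here is a small sketch of the long-format frame prepare_plot() expects (columns 'Episode', 'Score' and a 'Measurement' hue column); the import path and all values are assumptions for illustration.

# Sketch only: the long-format input prepare_plot() expects, one row per
# (Episode, Measurement) pair with the value in 'Score'. Values are invented.
import pandas as pd
from environments.logging.plotting import prepare_plot  # assumed path, see the first hunk

results_df = pd.DataFrame({
    'Episode':     [0, 0, 1, 1, 2, 2],
    'Score':       [0.10, 4.0, 0.15, 3.5, 0.20, 3.0],
    'Measurement': ['step_reward', 'dirt_amount'] * 3,
})

# prepare_plot() first tries a LaTeX-rendered lineplot and falls back to plain
# text rendering if LaTeX is unavailable; plot() then saves the figure under
# the given filepath (the parent directory must already exist) and shows it.
prepare_plot(filepath='monitor_example', results_df=results_df)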

View File

@@ -3,11 +3,10 @@ from collections import defaultdict
from pathlib import Path
from typing import Union
import pandas as pd
import simplejson
from stable_baselines3.common.callbacks import BaseCallback
from environments.factory.base.base_factory import REC_TAC
from environments.helpers import IGNORED_DF_COLUMNS
# noinspection PyAttributeOutsideInit
@@ -18,8 +17,8 @@ class RecorderCallback(BaseCallback):
self.trajectory_map = trajectory_map
self.occupation_map = occupation_map
self.filepath = Path(filepath)
self._recorder_dict = defaultdict(dict)
self._recorder_json_list = list()
self._recorder_dict = defaultdict(list)
self._recorder_out_list = list()
self.do_record: bool
self.started = False
self.closed = False
@@ -27,15 +26,15 @@ class RecorderCallback(BaseCallback):
def read_info(self, env_idx, info: dict):
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
info_dict.update(episode=(self.num_timesteps + env_idx))
self._recorder_dict[env_idx][len(self._recorder_dict[env_idx])] = info_dict
self._recorder_dict[env_idx].append(info_dict)
else:
pass
return
def read_done(self, env_idx, done):
if done:
self._recorder_json_list.append(json.dumps(self._recorder_dict[env_idx]))
self._recorder_dict[env_idx] = dict()
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx]})
self._recorder_dict[env_idx] = list()
else:
pass
@@ -51,8 +50,11 @@ class RecorderCallback(BaseCallback):
if self.do_record and self.started:
# self.out_file.unlink(missing_ok=True)
with self.filepath.open('w') as f:
json_list = self._recorder_json_list
json.dump(json_list, f, indent=4)
out_dict = {'episodes': self._recorder_out_list}
try:
simplejson.dump(out_dict, f, indent=4)
except TypeError:
print('Shit')
if self.occupation_map:
print('Recorder files were dumped to disk, now plotting the occupation map...')
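
For completeness: the recorder rework above replaces the per-step dict-of-dicts and the list of pre-serialised JSON strings with plain Python structures that are dumped once via simplejson. The resulting file has the shape sketched below; the concrete step keys and values are invented for illustration, only the 'episodes' / 'steps' / 'episode' nesting follows from the code.

# Sketch only: shape of the recorder output after this commit. One 'episodes'
# list, one {'steps': [...]} entry per finished episode (appended in
# read_done), and per-step dicts whose keys are the REC_TAC-prefixed info keys
# with the prefix stripped. The concrete keys and values below are made up.
import simplejson

out_dict = {
    'episodes': [
        {'steps': [
            {'dirt_amount': 3.2, 'agent_pos': [4, 7], 'episode': 128},
            {'dirt_amount': 2.9, 'agent_pos': [4, 8], 'episode': 128},
        ]},
    ]
}

with open('recording_example.json', 'w') as f:
    # simplejson serialises a few extra types (e.g. Decimal) that the stdlib
    # json module rejects; anything still unserialisable raises TypeError,
    # which the callback above currently only reports via print().
    simplejson.dump(out_dict, f, indent=4)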