mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 19:31:34 +02:00
plotting for rolling mean over stats
This commit is contained in:
@ -101,10 +101,18 @@ class MonitorCallback(BaseCallback):
|
||||
# %% Load MonitorList from Disk
|
||||
with self.filepath.open('rb') as f:
|
||||
monitor_list = pickle.load(f)
|
||||
|
||||
result = pd.concat(monitor_list, sort=False)
|
||||
df = None
|
||||
for m_idx, monitor in enumerate(monitor_list):
|
||||
monitor['episode'] = m_idx
|
||||
if df is None:
|
||||
df = pd.DataFrame(columns=monitor.columns)
|
||||
for _, row in monitor.iterrows():
|
||||
df.loc[df.shape[0]] = row
|
||||
for column in list(df.columns):
|
||||
if column != 'episode':
|
||||
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
||||
# result.tail()
|
||||
prepare_plot(filepath=self.filepath, results_df=result, tag='monitor')
|
||||
prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")), tag='monitor')
|
||||
print('Plotting done.')
|
||||
self.closed = True
|
||||
|
||||
|
Reference in New Issue
Block a user