99 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			99 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pickle
 | |
| from collections import defaultdict
 | |
| from pathlib import Path
 | |
| from typing import List, Dict
 | |
| 
 | |
| from stable_baselines3.common.callbacks import BaseCallback
 | |
| 
 | |
| from environments.helpers import IGNORED_DF_COLUMNS
 | |
| from environments.logging.plotting import prepare_plot
 | |
| import pandas as pd
 | |
| 
 | |
| 
 | |
class MonitorCallback(BaseCallback):
    """Collect per-step ``info`` dicts from the env(s) and summarise them per episode.

    Every step, each env's info dict (minus ``'terminal_observation'``, ``'episode'``
    and keys starting with ``'rec_'``) is buffered. When an env reports ``done``, its
    buffered infos are reduced to a single row and appended to an internal DataFrame:
    columns listed in ``IGNORED_DF_COLUMNS`` are dropped, columns whose name ends in
    ``'ount'`` are averaged, and all other columns are summed. An ``'episode'`` column
    numbers the rows across all envs.

    On training end the DataFrame (``reset_index()``-ed) is pickled to ``filepath`` and,
    if ``plotting`` is enabled, rolling means of its numeric columns are handed to
    ``prepare_plot``.

    Usable as an SB3 callback, or as a context manager around a manual loop that calls
    ``_on_step(alt_infos=..., alt_dones=...)`` itself.
    """

    ext = 'png'

    def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True, rolling_window=50):
        """
        :param filepath: Destination of the pickled per-episode DataFrame. Its parent
            directory is created on training start/end.
        :param plotting: If True, plot rolling means after dumping the DataFrame.
        :param rolling_window: Window size (in episodes/rows) of the rolling mean used
            for plotting.
        """
        super().__init__()
        self.filepath = Path(filepath)
        self.plotting = plotting
        self.rolling_window = rolling_window
        # One row per finished episode, across all envs.
        self._monitor_df = pd.DataFrame()
        # env index -> {step index within the current episode -> filtered info dict}
        self._monitor_dicts = defaultdict(dict)
        self.started = False
        self.closed = False

    def __enter__(self):
        self._on_training_start()
        # Return self so ``with MonitorCallback(...) as monitor:`` binds the callback.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Dump (and optionally plot) even if the body raised; the exception is not suppressed.
        self._on_training_end()

    def _on_training_start(self) -> None:
        """Create the output directory; only the first call has an effect."""
        if self.started:
            return
        self.filepath.parent.mkdir(exist_ok=True, parents=True)
        self.started = True

    def _on_training_end(self) -> None:
        """Pickle the per-episode DataFrame and optionally plot it; only the first call has an effect."""
        if self.closed:
            return
        # Also ensure the directory here, in case training start was never signalled.
        self.filepath.parent.mkdir(exist_ok=True, parents=True)
        with self.filepath.open('wb') as f:
            pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
        if self.plotting:
            print('Monitor files were dumped to disk, now plotting....')
            # Plot from memory: the pickle holds a single DataFrame, so there is no need
            # to read it back (the previous reload iterated it like a list and crashed).
            if self._plot_monitor_df():
                print('Plotting done.')
        self.closed = True

    def _plot_monitor_df(self) -> bool:
        """Plot rolling means of the collected episodes; return False if there was nothing to plot."""
        df = self._monitor_df
        if df.empty:  # The env exited prematurely (no finished episode), nothing to plot.
            return False
        rolled = {
            f'{column}_roll': df[column].rolling(window=self.rolling_window).mean()
            for column in df.columns
            # Rolling means only make sense for numeric columns; 'episode' is just a counter.
            if column != 'episode' and pd.api.types.is_numeric_dtype(df[column])
        }
        prepare_plot(filepath=self.filepath, results_df=pd.DataFrame(rolled, index=df.index))
        return True

    def _close_episode(self, env_idx: int) -> None:
        """Reduce env ``env_idx``'s buffered step infos to one row, append it, and reset the buffer."""
        steps_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
        self._monitor_dicts[env_idx] = dict()
        columns = [col for col in steps_df.columns if col not in IGNORED_DF_COLUMNS]
        if columns:
            # Columns whose name ends in 'ount' are averaged over the episode, all others summed.
            summary = steps_df.aggregate({col: 'mean' if col.endswith('ount') else 'sum' for col in columns})
        else:
            # aggregate({}) cannot produce a row; record just the episode number.
            summary = pd.Series(dtype='float64')
        summary['episode'] = len(self._monitor_df)
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead.
        row = pd.DataFrame([summary])
        if self._monitor_df.empty:
            self._monitor_df = row
        else:
            self._monitor_df = pd.concat([self._monitor_df, row], ignore_index=True)

    def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
        """Buffer this step's info dicts and close out every episode that just ended.

        :param alt_infos: Info dicts to use instead of ``self.locals['infos']``
            (for loops that are not driven by SB3).
        :param alt_dones: Done flags to use instead of ``self.locals['dones']`` or
            ``self.locals['done']``.
        :return: Always True, so training is never stopped by this callback.
        """
        infos = alt_infos if alt_infos is not None else self.locals.get('infos', [])
        if alt_dones is not None:
            dones = alt_dones
        elif self.locals.get('dones') is not None:
            dones = self.locals['dones']
        elif self.locals.get('done') is not None:
            dones = self.locals['done']
        else:
            dones = []

        # NOTE(review): a single, non-vectorised env presumably yields one info dict and a
        # scalar done flag (e.g. under the 'done' key); wrap them so they zip as env 0.
        if isinstance(infos, dict):
            infos = [infos]
        if not hasattr(dones, '__iter__'):
            dones = [dones]

        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            steps = self._monitor_dicts[env_idx]
            steps[len(steps)] = {key: val for key, val in info.items()
                                 if key not in ('terminal_observation', 'episode')
                                 and not key.startswith('rec_')}
            if done:
                self._close_episode(env_idx)
        return True
 | |
| 
 | 
