36 lines
1.0 KiB
Python
36 lines
1.0 KiB
Python
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
|
|
|
|
class TraningMonitor(BaseCallback):
    """Stable-Baselines3 callback that records the per-step reward to a CSV file.

    Rewards are buffered in memory, keyed by ``num_timesteps``. The buffer is
    appended to ``filepath`` once at least ``flush_interval`` timesteps have
    elapsed since the previous flush, and once more when training ends. If the
    file does not exist yet, it is created with a header row (``timestep,reward``).
    Otherwise rows are appended without a header.

    NOTE(review): the class name is misspelled ("Traning"); it is kept as-is so
    existing imports keep working.

    NOTE(review): an existing file is appended to rather than truncated, so a
    new run writing to the same path continues the old file — confirm that is
    intended.

    :param filepath: Path (or string) of the CSV file to write.
    :param flush_interval: Number of timesteps between flushes. When ``None``
        (or 0), it defaults to 10% of the ``total_timesteps`` value found in the
        training locals at training start, with a minimum of 1.
    """

    def __init__(self, filepath, flush_interval=None):
        super().__init__()
        # Buffered rows awaiting a flush: {timestep: {"reward": value}}.
        self.values = {}
        self.filepath = Path(filepath)
        self.flush_interval = flush_interval
        # Timestep at which the buffer was last written out.
        self._last_flush_timestep = 0

    def _on_training_start(self) -> None:
        """Resolve the default flush interval and reset the flush bookkeeping."""
        if not self.flush_interval:
            # Integer arithmetic with a floor of 1. The previous float default
            # (total * 0.1) could be < 1 or non-integral, making a modulo test
            # fire on every step or never.
            total_timesteps = int(self.locals["total_timesteps"])
            self.flush_interval = max(1, total_timesteps // 10)
        self._last_flush_timestep = self.num_timesteps

    def _flush(self) -> None:
        """Append the buffered rows to the CSV file and clear the buffer."""
        if not self.values:
            # Nothing recorded since the last flush; don't write an empty frame.
            return
        # orient="index": one row per timestep, one column per recorded value.
        # The default orient ("columns") would transpose the table.
        df = pd.DataFrame.from_dict(self.values, orient="index")
        if self.filepath.exists():
            df.to_csv(self.filepath, mode="a", header=False)
        else:
            df.to_csv(self.filepath, mode="w", header=True, index_label="timestep")
        self.values = {}
        self._last_flush_timestep = self.num_timesteps

    def _on_step(self) -> bool:
        """Record this step's reward and flush when the interval has elapsed.

        :return: Always ``True`` so training continues.
        """
        # NOTE(review): ``.item()`` succeeds only when ``rewards`` holds exactly
        # one value — presumably a single (non-vectorized) environment. With
        # several envs it raises; confirm whether multi-env support is needed.
        self.values[self.num_timesteps] = {"reward": self.locals["rewards"].item()}
        # Compare elapsed timesteps instead of testing ``% interval == 0``, which
        # can be skipped when num_timesteps advances by more than 1 per call.
        if self.num_timesteps - self._last_flush_timestep >= self.flush_interval:
            self._flush()
        return True

    def _on_training_end(self) -> None:
        """Write out any rows still buffered when training finishes."""
        self._flush()
|
|
|