mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 19:31:34 +02:00
Logging Monitor Callback
This commit is contained in:
@ -56,12 +56,10 @@ class FactoryMonitor:
|
||||
|
||||
class MonitorCallback(BaseCallback):
|
||||
|
||||
def __init__(self, env, outpath='debug_out', filename='monitor'):
|
||||
def __init__(self, env, filepath=Path('debug_out/monitor.pick')):
|
||||
super(MonitorCallback, self).__init__()
|
||||
self._outpath = Path(outpath)
|
||||
self._filename = filename
|
||||
self.filepath = Path(filepath)
|
||||
self._monitor_list = list()
|
||||
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
|
||||
self.env = env
|
||||
self.started = False
|
||||
self.closed = False
|
||||
@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback):
|
||||
if self.started:
|
||||
pass
|
||||
else:
|
||||
self.out_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
self.started = True
|
||||
pass
|
||||
|
||||
@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback):
|
||||
pass
|
||||
else:
|
||||
# self.out_file.unlink(missing_ok=True)
|
||||
with self.out_file.open('wb') as f:
|
||||
with self.filepath.open('wb') as f:
|
||||
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.closed = True
|
||||
|
||||
|
35
environments/logging/training.py
Normal file
35
environments/logging/training.py
Normal file
@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
|
||||
class TraningMonitor(BaseCallback):
|
||||
|
||||
def __init__(self, filepath, flush_interval=None):
|
||||
super(TraningMonitor, self).__init__()
|
||||
self.values = dict()
|
||||
self.filepath = Path(filepath)
|
||||
self.flush_interval = flush_interval
|
||||
pass
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1)
|
||||
|
||||
def _flush(self):
|
||||
df = pd.DataFrame.from_dict(self.values)
|
||||
if not self.filepath.exists():
|
||||
df.to_csv(self.filepath, mode='wb', header=True)
|
||||
else:
|
||||
df.to_csv(self.filepath, mode='a', header=False)
|
||||
self.values = dict()
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
|
||||
if self.num_timesteps % self.flush_interval == 0:
|
||||
self._flush()
|
||||
return True
|
||||
|
||||
def on_training_end(self) -> None:
|
||||
self._flush()
|
||||
|
Reference in New Issue
Block a user