mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
plotting
This commit is contained in:
@ -1,35 +1,54 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
from environments.logging.plotting import prepare_plot
|
||||
|
||||
|
||||
class TraningMonitor(BaseCallback):
|
||||
|
||||
def __init__(self, filepath, flush_interval=None):
|
||||
super(TraningMonitor, self).__init__()
|
||||
self.values = dict()
|
||||
self.values = defaultdict(dict)
|
||||
self.rewards = defaultdict(lambda: 0)
|
||||
|
||||
self.filepath = Path(filepath)
|
||||
self.flush_interval = flush_interval
|
||||
self.next_flush: int
|
||||
pass
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1)
|
||||
self.next_flush = self.flush_interval
|
||||
|
||||
def _flush(self):
|
||||
df = pd.DataFrame.from_dict(self.values)
|
||||
df = pd.DataFrame.from_dict(self.values, orient='index')
|
||||
if not self.filepath.exists():
|
||||
df.to_csv(self.filepath, mode='wb', header=True)
|
||||
else:
|
||||
df.to_csv(self.filepath, mode='a', header=False)
|
||||
self.values = dict()
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
|
||||
if self.num_timesteps % self.flush_interval == 0:
|
||||
for idx, done in np.ndenumerate(self.locals['dones']):
|
||||
idx = idx[0]
|
||||
# self.values[self.num_timesteps].update(**{f'reward_env_{idx}': self.locals['rewards'][idx]})
|
||||
self.rewards[idx] += self.locals['rewards'][idx]
|
||||
if done:
|
||||
self.values[self.num_timesteps].update(**{f'acc_epispde_r_env_{idx}': self.rewards[idx]})
|
||||
self.rewards[idx] = 0
|
||||
|
||||
if self.num_timesteps >= self.next_flush and self.values:
|
||||
self._flush()
|
||||
self.values = defaultdict(dict)
|
||||
|
||||
self.next_flush += self.flush_interval
|
||||
return True
|
||||
|
||||
def on_training_end(self) -> None:
|
||||
self._flush()
|
||||
self.values = defaultdict(dict)
|
||||
# prepare_plot()
|
||||
|
||||
|
Reference in New Issue
Block a user