recorder fixed
This commit is contained in:
parent
4f3924d3ab
commit
6c2df735d4
@ -7,6 +7,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import simplejson
|
||||
from deepdiff.operator import BaseOperator
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
from environments.factory.base.base_factory import REC_TAC
|
||||
@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback):
|
||||
self._on_training_end()
|
||||
return True
|
||||
|
||||
def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False):
|
||||
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||
only_deltas=True,
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
# cls.out_file.unlink(missing_ok=True)
|
||||
with filepath.open('w') as f:
|
||||
out_dict = {'n_episodes': self._episode_counter,
|
||||
if only_deltas:
|
||||
from deepdiff import DeepDiff, Delta
|
||||
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
|
||||
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||
]
|
||||
out_dict = {'episodes': diff_dict}
|
||||
|
||||
else:
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._episode_counter,
|
||||
'env_params': self.unwrapped.params,
|
||||
'header': self.unwrapped.summarize_header,
|
||||
'episodes': self._recorder_out_list
|
||||
}
|
||||
'header': self.unwrapped.summarize_header
|
||||
})
|
||||
try:
|
||||
simplejson.dump(out_dict, f, indent=4)
|
||||
except TypeError:
|
||||
|
Loading…
x
Reference in New Issue
Block a user