recorder fixed

This commit is contained in:
Steffen Illium 2022-12-15 13:28:22 +01:00
parent 4f3924d3ab
commit 6c2df735d4

View File

@ -7,6 +7,7 @@ from typing import Union
import numpy as np
import pandas as pd
import simplejson
from deepdiff.operator import BaseOperator
from stable_baselines3.common.callbacks import BaseCallback
from environments.factory.base.base_factory import REC_TAC
@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback):
self._on_training_end()
return True
def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False):
def save_records(self, filepath: Union[Path, str, None] = None,
only_deltas=True,
save_occupation_map=False,
save_trajectory_map=False,
):
filepath = Path(filepath or self.filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
# cls.out_file.unlink(missing_ok=True)
with filepath.open('w') as f:
out_dict = {'n_episodes': self._episode_counter,
if only_deltas:
from deepdiff import DeepDiff, Delta
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
]
out_dict = {'episodes': diff_dict}
else:
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._episode_counter,
'env_params': self.unwrapped.params,
'header': self.unwrapped.summarize_header,
'episodes': self._recorder_out_list
}
'header': self.unwrapped.summarize_header
})
try:
simplejson.dump(out_dict, f, indent=4)
except TypeError: