diff --git a/environments/logging/recorder.py b/environments/logging/recorder.py index 893a0a3..bf83959 100644 --- a/environments/logging/recorder.py +++ b/environments/logging/recorder.py @@ -7,6 +7,7 @@ from typing import Union import numpy as np import pandas as pd import simplejson +from deepdiff.operator import BaseOperator from stable_baselines3.common.callbacks import BaseCallback from environments.factory.base.base_factory import REC_TAC @@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback): self._on_training_end() return True - def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False): + def save_records(self, filepath: Union[Path, str, None] = None, + only_deltas=True, + save_occupation_map=False, + save_trajectory_map=False, + ): filepath = Path(filepath or self.filepath) filepath.parent.mkdir(exist_ok=True, parents=True) # cls.out_file.unlink(missing_ok=True) with filepath.open('w') as f: - out_dict = {'n_episodes': self._episode_counter, - 'env_params': self.unwrapped.params, - 'header': self.unwrapped.summarize_header, - 'episodes': self._recorder_out_list - } + if only_deltas: + from deepdiff import DeepDiff, Delta + diff_dict = [DeepDiff(t1,t2, ignore_order=True) + for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:]) + ] + out_dict = {'episodes': diff_dict} + + else: + out_dict = {'episodes': self._recorder_out_list} + out_dict.update( + {'n_episodes': self._episode_counter, + 'env_params': self.unwrapped.params, + 'header': self.unwrapped.summarize_header + }) try: simplejson.dump(out_dict, f, indent=4) except TypeError: