Merge remote-tracking branch 'origin/main'

This commit is contained in:
Robert Müller 2022-12-15 17:17:43 +01:00
commit bcbd4a8078

View File

@ -7,6 +7,7 @@ from typing import Union
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import simplejson import simplejson
from deepdiff.operator import BaseOperator
from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.callbacks import BaseCallback
from environments.factory.base.base_factory import REC_TAC from environments.factory.base.base_factory import REC_TAC
@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback):
self._on_training_end() self._on_training_end()
return True return True
def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False): def save_records(self, filepath: Union[Path, str, None] = None,
only_deltas=True,
save_occupation_map=False,
save_trajectory_map=False,
):
filepath = Path(filepath or self.filepath) filepath = Path(filepath or self.filepath)
filepath.parent.mkdir(exist_ok=True, parents=True) filepath.parent.mkdir(exist_ok=True, parents=True)
# cls.out_file.unlink(missing_ok=True) # cls.out_file.unlink(missing_ok=True)
with filepath.open('w') as f: with filepath.open('w') as f:
out_dict = {'n_episodes': self._episode_counter, if only_deltas:
'env_params': self.unwrapped.params, from deepdiff import DeepDiff, Delta
'header': self.unwrapped.summarize_header, diff_dict = [DeepDiff(t1,t2, ignore_order=True)
'episodes': self._recorder_out_list for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
} ]
out_dict = {'episodes': diff_dict}
else:
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._episode_counter,
'env_params': self.unwrapped.params,
'header': self.unwrapped.summarize_header
})
try: try:
simplejson.dump(out_dict, f, indent=4) simplejson.dump(out_dict, f, indent=4)
except TypeError: except TypeError: