Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
bcbd4a8078
@ -7,6 +7,7 @@ from typing import Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import simplejson
|
import simplejson
|
||||||
|
from deepdiff.operator import BaseOperator
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
|
||||||
from environments.factory.base.base_factory import REC_TAC
|
from environments.factory.base.base_factory import REC_TAC
|
||||||
@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback):
|
|||||||
self._on_training_end()
|
self._on_training_end()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False):
|
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||||
|
only_deltas=True,
|
||||||
|
save_occupation_map=False,
|
||||||
|
save_trajectory_map=False,
|
||||||
|
):
|
||||||
filepath = Path(filepath or self.filepath)
|
filepath = Path(filepath or self.filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
# cls.out_file.unlink(missing_ok=True)
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
out_dict = {'n_episodes': self._episode_counter,
|
if only_deltas:
|
||||||
'env_params': self.unwrapped.params,
|
from deepdiff import DeepDiff, Delta
|
||||||
'header': self.unwrapped.summarize_header,
|
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
|
||||||
'episodes': self._recorder_out_list
|
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||||
}
|
]
|
||||||
|
out_dict = {'episodes': diff_dict}
|
||||||
|
|
||||||
|
else:
|
||||||
|
out_dict = {'episodes': self._recorder_out_list}
|
||||||
|
out_dict.update(
|
||||||
|
{'n_episodes': self._episode_counter,
|
||||||
|
'env_params': self.unwrapped.params,
|
||||||
|
'header': self.unwrapped.summarize_header
|
||||||
|
})
|
||||||
try:
|
try:
|
||||||
simplejson.dump(out_dict, f, indent=4)
|
simplejson.dump(out_dict, f, indent=4)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user