mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 04:37:25 +01:00 
			
		
		
		
	recorder fixed
This commit is contained in:
		| @@ -7,6 +7,7 @@ from typing import Union | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| import simplejson | ||||
| from deepdiff.operator import BaseOperator | ||||
| from stable_baselines3.common.callbacks import BaseCallback | ||||
|  | ||||
| from environments.factory.base.base_factory import REC_TAC | ||||
| @@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback): | ||||
|         self._on_training_end() | ||||
|         return True | ||||
|  | ||||
|     def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False): | ||||
|     def save_records(self, filepath: Union[Path, str, None] = None, | ||||
|                      only_deltas=True, | ||||
|                      save_occupation_map=False, | ||||
|                      save_trajectory_map=False, | ||||
|                      ): | ||||
|         filepath = Path(filepath or self.filepath) | ||||
|         filepath.parent.mkdir(exist_ok=True, parents=True) | ||||
|         # cls.out_file.unlink(missing_ok=True) | ||||
|         with filepath.open('w') as f: | ||||
|             out_dict = {'n_episodes': self._episode_counter, | ||||
|                         'env_params': self.unwrapped.params, | ||||
|                         'header': self.unwrapped.summarize_header, | ||||
|                         'episodes': self._recorder_out_list | ||||
|                         } | ||||
|             if only_deltas: | ||||
|                 from deepdiff import DeepDiff, Delta | ||||
|                 diff_dict = [DeepDiff(t1,t2, ignore_order=True) | ||||
|                              for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:]) | ||||
|                              ] | ||||
|                 out_dict = {'episodes': diff_dict} | ||||
|  | ||||
|             else: | ||||
|                 out_dict = {'episodes': self._recorder_out_list} | ||||
|             out_dict.update( | ||||
|                 {'n_episodes': self._episode_counter, | ||||
|                  'env_params': self.unwrapped.params, | ||||
|                  'header': self.unwrapped.summarize_header | ||||
|                  }) | ||||
|             try: | ||||
|                 simplejson.dump(out_dict, f, indent=4) | ||||
|             except TypeError: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium