mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	recorder fixed
This commit is contained in:
		| @@ -7,6 +7,7 @@ from typing import Union | |||||||
| import numpy as np | import numpy as np | ||||||
| import pandas as pd | import pandas as pd | ||||||
| import simplejson | import simplejson | ||||||
|  | from deepdiff.operator import BaseOperator | ||||||
| from stable_baselines3.common.callbacks import BaseCallback | from stable_baselines3.common.callbacks import BaseCallback | ||||||
|  |  | ||||||
| from environments.factory.base.base_factory import REC_TAC | from environments.factory.base.base_factory import REC_TAC | ||||||
| @@ -71,16 +72,29 @@ class EnvRecorder(BaseCallback): | |||||||
|         self._on_training_end() |         self._on_training_end() | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
|     def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False): |     def save_records(self, filepath: Union[Path, str, None] = None, | ||||||
|  |                      only_deltas=True, | ||||||
|  |                      save_occupation_map=False, | ||||||
|  |                      save_trajectory_map=False, | ||||||
|  |                      ): | ||||||
|         filepath = Path(filepath or self.filepath) |         filepath = Path(filepath or self.filepath) | ||||||
|         filepath.parent.mkdir(exist_ok=True, parents=True) |         filepath.parent.mkdir(exist_ok=True, parents=True) | ||||||
|         # cls.out_file.unlink(missing_ok=True) |         # cls.out_file.unlink(missing_ok=True) | ||||||
|         with filepath.open('w') as f: |         with filepath.open('w') as f: | ||||||
|             out_dict = {'n_episodes': self._episode_counter, |             if only_deltas: | ||||||
|                         'env_params': self.unwrapped.params, |                 from deepdiff import DeepDiff, Delta | ||||||
|                         'header': self.unwrapped.summarize_header, |                 diff_dict = [DeepDiff(t1,t2, ignore_order=True) | ||||||
|                         'episodes': self._recorder_out_list |                              for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:]) | ||||||
|                         } |                              ] | ||||||
|  |                 out_dict = {'episodes': diff_dict} | ||||||
|  |  | ||||||
|  |             else: | ||||||
|  |                 out_dict = {'episodes': self._recorder_out_list} | ||||||
|  |             out_dict.update( | ||||||
|  |                 {'n_episodes': self._episode_counter, | ||||||
|  |                  'env_params': self.unwrapped.params, | ||||||
|  |                  'header': self.unwrapped.summarize_header | ||||||
|  |                  }) | ||||||
|             try: |             try: | ||||||
|                 simplejson.dump(out_dict, f, indent=4) |                 simplejson.dump(out_dict, f, indent=4) | ||||||
|             except TypeError: |             except TypeError: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium