diff --git a/environments/factory/env_default_param.yaml b/environments/factory/env_default_param.yaml
new file mode 100644
index 0000000..9c0c139
--- /dev/null
+++ b/environments/factory/env_default_param.yaml
@@ -0,0 +1,23 @@
+combin_agent_slices_in_obs: true
+dirt_properties: !!python/object/new:environments.factory.simple_factory.DirtProperties
+- 1
+- 0.05
+- 0.1
+- 3
+- 1
+- 20
+- 0.0
+done_at_collision: false
+frames_to_stack: 0
+level_name: rooms
+max_steps: 400
+movement_properties: !!python/object/new:environments.utility_classes.MovementProperties
+- true
+- true
+- false
+n_agents: 1
+omit_agent_slice_in_obs: true
+parse_doors: false
+pomdp_radius: 3
+record_episodes: false
+verbose: false
diff --git a/environments/logging/recorder.py b/environments/logging/recorder.py
index dec9ee1..d2cbb91 100644
--- a/environments/logging/recorder.py
+++ b/environments/logging/recorder.py
@@ -18,27 +18,34 @@ class RecorderCallback(BaseCallback):
         self.filepath = Path(filepath)
         self._recorder_dict = dict()
         self._recorder_df = pd.DataFrame()
+        # Assigned, not merely annotated: a bare annotation creates no attribute, and both
+        # _on_step and _on_training_end read this flag even if _on_training_start never set it.
+        self.do_record: bool = False
         self.started = False
         self.closed = False
 
     def _on_step(self) -> bool:
-        for _, info in enumerate(self.locals.get('infos', [])):
-            self._recorder_dict[self.num_timesteps] = {key: val for key, val in info.items()
-                                                       if not key.startswith(f'{REC_TAC}_')}
+        # Collect step data only while recording is enabled and has been started.
+        if self.do_record and self.started:
+            for _, info in enumerate(self.locals.get('infos', [])):
+                self._recorder_dict[self.num_timesteps] = {key: val for key, val in info.items()
+                                                           if not key.startswith(f'{REC_TAC}_')}
 
-        for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \
-                             list(enumerate(self.locals.get('done', []))):
-            if done:
-                env_monitor_df = pd.DataFrame.from_dict(self._recorder_dict, orient='index')
-                self._recorder_dict = dict()
-                columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
-                env_monitor_df = env_monitor_df.aggregate(
-                    {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
-                )
-                env_monitor_df['episode'] = len(self._recorder_df)
-                self._recorder_df = self._recorder_df.append([env_monitor_df])
-            else:
-                pass
+            for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \
+                                 list(enumerate(self.locals.get('done', []))):
+                if done:
+                    # Episode finished: collapse the buffered per-step records into one row
+                    # ('mean' for columns ending in 'ount', 'sum' otherwise) and reset the buffer.
+                    env_monitor_df = pd.DataFrame.from_dict(self._recorder_dict, orient='index')
+                    self._recorder_dict = dict()
+                    columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
+                    env_monitor_df = env_monitor_df.aggregate(
+                        {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
+                    )
+                    env_monitor_df['episode'] = len(self._recorder_df)
+                    self._recorder_df = self._recorder_df.append([env_monitor_df])
+                else:
+                    pass
         return True
 
     def __enter__(self):
@@ -51,24 +58,30 @@ class RecorderCallback(BaseCallback):
         if self.started:
             pass
         else:
-            self.filepath.parent.mkdir(exist_ok=True, parents=True)
-            self.started = True
+            # Record only when the training env exposes a truthy `record_episodes`;
+            # a missing attribute counts as "recording disabled".
+            self.do_record = bool(getattr(self.training_env, 'record_episodes', False))
+            if self.do_record:
+                self.filepath.parent.mkdir(exist_ok=True, parents=True)
+                self.started = True
         pass
 
     def _on_training_end(self) -> None:
         if self.closed:
             pass
         else:
-            # self.out_file.unlink(missing_ok=True)
-            with self.filepath.open('w') as f:
-                json_df = self._recorder_df.to_json(orient="table")
-                parsed = json.loads(json_df)
-                json.dump(parsed, f, indent=4)
+            if self.do_record and self.started:
+                # self.out_file.unlink(missing_ok=True)
+                with self.filepath.open('w') as f:
+                    json_df = self._recorder_df.to_json(orient="table")
+                    parsed = json.loads(json_df)
+                    json.dump(parsed, f, indent=4)
 
-            if self.occupation_map:
-                print('Recorder files were dumped to disk, now plotting the occupation map...')
+                if self.occupation_map:
+                    print('Recorder files were dumped to disk, now plotting the occupation map...')
 
-            if self.trajectory_map:
-                print('Recorder files were dumped to disk, now plotting the occupation map...')
+                if self.trajectory_map:
+                    print('Recorder files were dumped to disk, now plotting the trajectory map...')
 
-            self.closed = True
\ No newline at end of file
+                self.closed = True
+                self.started = False
diff --git a/main.py b/main.py
index b1b5242..c424e01 100644
--- a/main.py
+++ b/main.py
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     # from sb3_contrib import QRDQN
 
     dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
-                                max_local_amount=1, spawn_frequency=10, max_spawn_ratio=0.05,
+                                max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
                                 dirt_smear_amount=0.0)
     move_props = MovementProperties(allow_diagonal_movement=True,
                                     allow_square_movement=True,
@@ -109,7 +109,7 @@ if __name__ == '__main__':
             with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=True,
                                movement_properties=move_props, level_name='rooms', frames_to_stack=3,
                                omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
-                               cast_shadows=True,
+                               cast_shadows=True, doors_have_area=True
                                ) as env:
 
                 if modeL_type.__name__ in ["PPO", "A2C"]: