diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index 311ca0c..55058f5 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -72,7 +72,8 @@ class MonitorCallback(BaseCallback): self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items() if key not in ['terminal_observation', 'episode']} - for env_idx, done in enumerate(self.locals.get('dones', [])): + for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \ + list(enumerate(self.locals.get('done', []))): if done: env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index') self._monitor_dict = dict() diff --git a/main.py b/main.py index 00fcb29..da6b757 100644 --- a/main.py +++ b/main.py @@ -41,8 +41,6 @@ def combine_runs(run_path: Union[str, PathLike]): value_vars=columns, var_name="Measurement", value_name="Score") df_melted = df_melted[df_melted['Episode'] % skip_n == 0] - #df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version - prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) print('Plotting done.') @@ -51,12 +49,14 @@ def combine_runs(run_path: Union[str, PathLike]): if __name__ == '__main__': from stable_baselines3 import PPO, DQN, A2C + from algorithms.dqn_reg import RegDQN + dirt_props = DirtProperties() time_stamp = int(time.time()) out_path = None - for modeL_type in [A2C, PPO, DQN]: + for modeL_type in [RegDQN, DQN]: for seed in range(5): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,