Minor fixes; the code now runs.
This commit is contained in:
		| @@ -74,13 +74,13 @@ class MonitorCallback(BaseCallback): | ||||
|             dones = alt_dones | ||||
|         elif self.locals.get('dones', None) is not None: | ||||
|             dones =self.locals.get('dones', None) | ||||
|         elif self.locals.get('dones', None) is not None: | ||||
|         elif self.locals.get('done', None) is not None: | ||||
|             dones = self.locals.get('done', [None]) | ||||
|         else: | ||||
|             dones = [] | ||||
|  | ||||
|         for env_idx, (info, done) in enumerate(zip(infos, dones)): | ||||
|             self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items() | ||||
|             self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items() | ||||
|                                                                 if key not in ['terminal_observation', 'episode'] | ||||
|                                                                 and not key.startswith('rec_')} | ||||
|             if done: | ||||
|   | ||||
| @@ -34,7 +34,7 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None) | ||||
|         sns.set(rc={'text.usetex': True}, style='whitegrid') | ||||
|         lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, | ||||
|                                 hue_order=hue_order, hue=hue, style=style) | ||||
|         lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|         # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|         plot(filepath, ext=ext)  # plot raises errors not lineplot! | ||||
|     except (FileNotFoundError, RuntimeError): | ||||
|         print('Struggling to plot Figure using LaTeX - going back to normal.') | ||||
| @@ -42,5 +42,5 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None) | ||||
|         sns.set(rc={'text.usetex': False}, style='whitegrid') | ||||
|         lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, | ||||
|                                      ci=95, palette=PALETTE, hue_order=hue_order) | ||||
|         lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|         # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') | ||||
|         plot(filepath, ext=ext) | ||||
|   | ||||
							
								
								
									
										27
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								main.py
									
									
									
									
									
								
							| @@ -34,18 +34,20 @@ def combine_runs(run_path: Union[str, PathLike]): | ||||
|         df_list.append(monitor_df) | ||||
|  | ||||
|     df = pd.concat(df_list,  ignore_index=True) | ||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) | ||||
|     df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode']) | ||||
|     columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] | ||||
|  | ||||
|     roll_n = 50 | ||||
|     skip_n = 40 | ||||
|  | ||||
|     non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean() | ||||
|  | ||||
|     df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'], | ||||
|                                                                 value_vars=columns, var_name="Measurement", | ||||
|                                                                 value_name="Score") | ||||
|     df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     if df_melted['Episode'].max() > 100: | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.01) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) | ||||
|     print('Plotting done.') | ||||
| @@ -71,14 +73,15 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List | ||||
|     columns = [col for col in df.columns if col in parameter] | ||||
|  | ||||
|     roll_n = 40 | ||||
|     skip_n = 20 | ||||
|  | ||||
|     non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean() | ||||
|  | ||||
|     df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'], | ||||
|                                                                 value_vars=columns, var_name="Measurement", | ||||
|                                                                 value_name="Score") | ||||
|     df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|     if df_melted['Episode'].max() > 100: | ||||
|         skip_n = round(df_melted['Episode'].max() * 0.01) | ||||
|         df_melted = df_melted[df_melted['Episode'] % skip_n == 0] | ||||
|  | ||||
|     style = 'Measurement' if len(columns) > 1 else None | ||||
|     prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style) | ||||
| @@ -113,7 +116,7 @@ if __name__ == '__main__': | ||||
|     move_props = MovementProperties(allow_diagonal_movement=False, | ||||
|                                     allow_square_movement=True, | ||||
|                                     allow_no_op=False) | ||||
|     train_steps = 1e6 | ||||
|     train_steps = 1e5 | ||||
|     time_stamp = int(time.time()) | ||||
|  | ||||
|     out_path = None | ||||
| @@ -131,12 +134,11 @@ if __name__ == '__main__': | ||||
|                               cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False, | ||||
|                               ) | ||||
|  | ||||
|             # env = make_env(env_kwargs)() | ||||
|             env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn") | ||||
|  | ||||
|             if modeL_type.__name__ in ["PPO", "A2C"]: | ||||
|                 kwargs = dict(ent_coef=0.01) | ||||
|                 env = SubprocVecEnv([make_env(env_kwargs) for _ in range(6)], start_method="spawn") | ||||
|             elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: | ||||
|                 env = make_env(env_kwargs)() | ||||
|                 kwargs = dict(buffer_size=50000, | ||||
|                               learning_starts=64, | ||||
|                               batch_size=64, | ||||
| @@ -145,6 +147,7 @@ if __name__ == '__main__': | ||||
|                               exploration_final_eps=0.025) | ||||
|             else: | ||||
|                 raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.') | ||||
|  | ||||
|             model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) | ||||
|  | ||||
|             out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' | ||||
| @@ -165,7 +168,11 @@ if __name__ == '__main__': | ||||
|             save_path = out_path / f'model_{identifier}.zip' | ||||
|             save_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|             model.save(save_path) | ||||
|             env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') | ||||
|             param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml' | ||||
|             try: | ||||
|                 env.env_method('save_params', param_path) | ||||
|             except AttributeError: | ||||
|                 env.save_params(param_path) | ||||
|             print("Model Trained and saved") | ||||
|         print("Model Group Done.. Plotting...") | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium