Smaller fixes, now running.

This commit is contained in:
Steffen Illium 2021-09-01 14:51:44 +02:00
parent 4fb32c98c6
commit 84eb381307
3 changed files with 21 additions and 14 deletions

View File

@ -74,13 +74,13 @@ class MonitorCallback(BaseCallback):
dones = alt_dones
elif self.locals.get('dones', None) is not None:
dones =self.locals.get('dones', None)
elif self.locals.get('dones', None) is not None:
elif self.locals.get('done', None) is not None:
dones = self.locals.get('done', [None])
else:
dones = []
for env_idx, (info, done) in enumerate(zip(infos, dones)):
self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items()
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
if key not in ['terminal_observation', 'episode']
and not key.startswith('rec_')}
if done:

View File

@ -34,7 +34,7 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
sns.set(rc={'text.usetex': True}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
hue_order=hue_order, hue=hue, style=style)
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
plot(filepath, ext=ext) # plot raises errors not lineplot!
except (FileNotFoundError, RuntimeError):
print('Struggling to plot Figure using LaTeX - going back to normal.')
@ -42,5 +42,5 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
sns.set(rc={'text.usetex': False}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order)
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
plot(filepath, ext=ext)

23
main.py
View File

@ -34,17 +34,19 @@ def combine_runs(run_path: Union[str, PathLike]):
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
roll_n = 50
skip_n = 40
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
value_vars=columns, var_name="Measurement",
value_name="Score")
if df_melted['Episode'].max() > 100:
skip_n = round(df_melted['Episode'].max() * 0.01)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
@ -71,13 +73,14 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
columns = [col for col in df.columns if col in parameter]
roll_n = 40
skip_n = 20
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
value_vars=columns, var_name="Measurement",
value_name="Score")
if df_melted['Episode'].max() > 100:
skip_n = round(df_melted['Episode'].max() * 0.01)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
style = 'Measurement' if len(columns) > 1 else None
@ -113,7 +116,7 @@ if __name__ == '__main__':
move_props = MovementProperties(allow_diagonal_movement=False,
allow_square_movement=True,
allow_no_op=False)
train_steps = 1e6
train_steps = 1e5
time_stamp = int(time.time())
out_path = None
@ -131,12 +134,11 @@ if __name__ == '__main__':
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
)
# env = make_env(env_kwargs)()
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn")
if modeL_type.__name__ in ["PPO", "A2C"]:
kwargs = dict(ent_coef=0.01)
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(6)], start_method="spawn")
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
env = make_env(env_kwargs)()
kwargs = dict(buffer_size=50000,
learning_starts=64,
batch_size=64,
@ -145,6 +147,7 @@ if __name__ == '__main__':
exploration_final_eps=0.025)
else:
raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.')
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
@ -165,7 +168,11 @@ if __name__ == '__main__':
save_path = out_path / f'model_{identifier}.zip'
save_path.parent.mkdir(parents=True, exist_ok=True)
model.save(save_path)
env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml'
try:
env.env_method('save_params', param_path)
except AttributeError:
env.save_params(param_path)
print("Model Trained and saved")
print("Model Group Done.. Plotting...")