Destinations implemented and debugged

This commit is contained in:
Steffen Illium
2021-12-06 15:46:26 +01:00
parent 3d81b7577d
commit 7f7a3d9a3b
13 changed files with 426 additions and 76 deletions

View File

@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline'
from stable_baselines3 import A2C
def policy_model_kwargs():
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True)
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
def dqn_model_kwargs():
@ -198,12 +198,12 @@ if __name__ == '__main__':
ood_run = True
plotting = True
train_steps = 5e6
train_steps = 1e7
n_seeds = 3
frames_to_stack = 3
# Define a global studi save path
start_time = 'rms_weight_decay_0' # int(time.time())
start_time = 'new_reward' # int(time.time())
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
# Define Global Env Parameters
@ -516,7 +516,7 @@ if __name__ == '__main__':
# df_melted["Measurements"] = df_melted["Measurement"] + " " + df_melted["monitor"]
# Plotting
# fig, ax = plt.subplots(figsize=(11.7, 8.27))
fig, ax = plt.subplots(figsize=(11.7, 8.27))
c = sns.catplot(data=df_melted[df_melted['env'] == env_name],
x='Measurement', hue='monitor', row='model', col='obs_mode', y='Score',
@ -525,7 +525,7 @@ if __name__ == '__main__':
c.set_xticklabels(rotation=65, horizontalalignment='right')
# c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
c.fig.suptitle(f"Cat plot for {env_name}")
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
plt.savefig(study_root_path / f'results_{n_agents}_agents_{env_name}.png')
pass