mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
Destinations implemented and debugged
This commit is contained in:
@ -75,7 +75,7 @@ baseline_monitor_file = 'e_1_baseline'
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
def policy_model_kwargs():
|
||||
return dict(gae_lambda=0.25, n_steps=16, max_grad_norm=0, use_rms_prop=True)
|
||||
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
||||
|
||||
|
||||
def dqn_model_kwargs():
|
||||
@ -198,12 +198,12 @@ if __name__ == '__main__':
|
||||
ood_run = True
|
||||
plotting = True
|
||||
|
||||
train_steps = 5e6
|
||||
train_steps = 1e7
|
||||
n_seeds = 3
|
||||
frames_to_stack = 3
|
||||
|
||||
# Define a global studi save path
|
||||
start_time = 'rms_weight_decay_0' # int(time.time())
|
||||
start_time = 'new_reward' # int(time.time())
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||
|
||||
# Define Global Env Parameters
|
||||
@ -516,7 +516,7 @@ if __name__ == '__main__':
|
||||
# df_melted["Measurements"] = df_melted["Measurement"] + " " + df_melted["monitor"]
|
||||
|
||||
# Plotting
|
||||
# fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
||||
fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
||||
|
||||
c = sns.catplot(data=df_melted[df_melted['env'] == env_name],
|
||||
x='Measurement', hue='monitor', row='model', col='obs_mode', y='Score',
|
||||
@ -525,7 +525,7 @@ if __name__ == '__main__':
|
||||
c.set_xticklabels(rotation=65, horizontalalignment='right')
|
||||
# c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
|
||||
c.fig.suptitle(f"Cat plot for {env_name}")
|
||||
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||
plt.tight_layout()
|
||||
plt.savefig(study_root_path / f'results_{n_agents}_agents_{env_name}.png')
|
||||
pass
|
||||
|
Reference in New Issue
Block a user