mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
Individual Rewards
This commit is contained in:
201
studies/e_1.py
201
studies/e_1.py
@ -2,6 +2,7 @@ import sys
|
||||
from pathlib import Path
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import itertools as it
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
@ -70,7 +71,7 @@ baseline_monitor_file = 'e_1_baseline_monitor.pick'
|
||||
|
||||
|
||||
def policy_model_kwargs():
|
||||
return dict(ent_coef=0.05)
|
||||
return dict()
|
||||
|
||||
|
||||
def dqn_model_kwargs():
|
||||
@ -100,6 +101,7 @@ def load_model_run_baseline(seed_path, env_to_run):
|
||||
# Load old env kwargs
|
||||
with next(seed_path.glob('*.json')).open('r') as f:
|
||||
env_kwargs = simplejson.load(f)
|
||||
env_kwargs.update(done_at_collision=True)
|
||||
# Monitor Init
|
||||
with MonitorCallback(filepath=seed_path / baseline_monitor_file) as monitor:
|
||||
# Init Env
|
||||
@ -134,6 +136,7 @@ def load_model_run_study(seed_path, env_to_run, additional_kwargs_dict):
|
||||
env_kwargs = simplejson.load(f)
|
||||
env_kwargs.update(
|
||||
n_agents=n_agents,
|
||||
done_at_collision=True,
|
||||
**additional_kwargs_dict.get('post_training_kwargs', {}))
|
||||
# Monitor Init
|
||||
with MonitorCallback(filepath=seed_path / ood_monitor_file) as monitor:
|
||||
@ -168,6 +171,31 @@ def load_model_run_study(seed_path, env_to_run, additional_kwargs_dict):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def start_mp_study_run(envs_map, policies_path):
|
||||
paths = list(y for y in policies_path.iterdir() if y.is_dir() and not (y / ood_monitor_file).exists())
|
||||
if paths:
|
||||
import multiprocessing as mp
|
||||
pool = mp.Pool(mp.cpu_count())
|
||||
print("Starting MP with: ", pool._processes, " Processes")
|
||||
_ = pool.starmap(load_model_run_study,
|
||||
it.product(paths,
|
||||
(envs_map[policies_path.parent.name][0],),
|
||||
(observation_modes[policies_path.parent.parent.name],))
|
||||
)
|
||||
|
||||
|
||||
def start_mp_baseline_run(envs_map, policies_path):
|
||||
paths = list(y for y in policies_path.iterdir() if y.is_dir() and not (y / baseline_monitor_file).exists())
|
||||
if paths:
|
||||
import multiprocessing as mp
|
||||
pool = mp.Pool(mp.cpu_count())
|
||||
print("Starting MP with: ", pool._processes, " Processes")
|
||||
_ = pool.starmap(load_model_run_baseline,
|
||||
it.product(paths,
|
||||
(envs_map[policies_path.parent.name][0],))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_steps = 5e6
|
||||
n_seeds = 3
|
||||
@ -215,75 +243,74 @@ if __name__ == '__main__':
|
||||
|
||||
# Define parameter versions according with #1,2[1,0,N],3
|
||||
observation_modes = {}
|
||||
if False:
|
||||
observation_modes.update({
|
||||
'seperate_1': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=1,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_0': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=0,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_N': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder='N',
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'in_lvl_obs': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.LEVEL,
|
||||
omit_agent_self=True,
|
||||
additional_agent_placeholder=None,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_1': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=1,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_0': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=0,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_N': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder='N',
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'in_lvl_obs': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.LEVEL,
|
||||
omit_agent_self=True,
|
||||
additional_agent_placeholder=None,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
# No further adjustment needed
|
||||
'no_obs': dict(
|
||||
@ -398,15 +425,7 @@ if __name__ == '__main__':
|
||||
for env_path in [x for x in obs_mode_path.iterdir() if x.is_dir()]:
|
||||
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
||||
# Iteration
|
||||
paths = list(y for y in policy_path.iterdir() if y.is_dir() \
|
||||
and not (y / baseline_monitor_file).exists())
|
||||
import multiprocessing as mp
|
||||
import itertools as it
|
||||
pool = mp.Pool(mp.cpu_count())
|
||||
result = pool.starmap(load_model_run_baseline,
|
||||
it.product(paths,
|
||||
(env_map[env_path.name][0],))
|
||||
)
|
||||
start_mp_baseline_run(env_map, policy_path)
|
||||
|
||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
# load_model_run_baseline(seed_path)
|
||||
@ -424,18 +443,9 @@ if __name__ == '__main__':
|
||||
# First seed path version
|
||||
# seed_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
||||
# Iteration
|
||||
import multiprocessing as mp
|
||||
import itertools as it
|
||||
pool = mp.Pool(mp.cpu_count())
|
||||
paths = list(y for y in policy_path.iterdir() if y.is_dir() \
|
||||
and not (y / ood_monitor_file).exists())
|
||||
# result = pool.starmap(load_model_run_study,
|
||||
# it.product(paths,
|
||||
# (env_map[env_path.name][0],),
|
||||
# (observation_modes[obs_mode],))
|
||||
# )
|
||||
for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
||||
start_mp_study_run(env_map, policy_path)
|
||||
#for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
# load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
||||
print('OOD Tracking Done')
|
||||
|
||||
# Plotting
|
||||
@ -497,15 +507,16 @@ if __name__ == '__main__':
|
||||
# df_melted["Measurements"] = df_melted["Measurement"] + " " + df_melted["monitor"]
|
||||
|
||||
# Plotting
|
||||
fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
||||
# fig, ax = plt.subplots(figsize=(11.7, 8.27))
|
||||
|
||||
c = sns.catplot(data=df_melted[df_melted['obs_mode'] == observation_folder.name],
|
||||
x='Measurement', hue='monitor', row='model', col='env', y='Score',
|
||||
sharey=False, kind="box", height=4, aspect=.7, legend_out=True,
|
||||
sharey=False, kind="box", height=4, aspect=.7, legend_out=False, legend=False,
|
||||
showfliers=False)
|
||||
c.set_xticklabels(rotation=65, horizontalalignment='right')
|
||||
c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
|
||||
# c.fig.subplots_adjust(top=0.9) # adjust the Figure in rp
|
||||
c.fig.suptitle(f"Cat plot for {observation_folder.name}")
|
||||
plt.tight_layout(pad=2)
|
||||
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||
plt.tight_layout()
|
||||
plt.savefig(study_root_path / f'results_{n_agents}_agents_{observation_folder.name}.png')
|
||||
pass
|
||||
|
Reference in New Issue
Block a user