mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
Debugging and collision rendering
This commit is contained in:
155
studies/e_1.py
155
studies/e_1.py
@ -147,7 +147,7 @@ def load_model_run_study(seed_path, env_to_run, additional_kwargs_dict):
|
||||
try:
|
||||
actions = [model.predict(
|
||||
np.stack([env_state[i][j] for i in range(env_state.shape[0])]),
|
||||
deterministic=False)[0] for j, model in enumerate(models)]
|
||||
deterministic=True)[0] for j, model in enumerate(models)]
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
print('Env_Kwargs are:\n')
|
||||
@ -169,10 +169,11 @@ def load_model_run_study(seed_path, env_to_run, additional_kwargs_dict):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_steps = 8e5
|
||||
train_steps = 5e6
|
||||
n_seeds = 3
|
||||
|
||||
# Define a global studi save path
|
||||
start_time = '900000' # int(time.time())
|
||||
start_time = 'Now_with_doors' # int(time.time())
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||
|
||||
# Define Global Env Parameters
|
||||
@ -195,57 +196,95 @@ if __name__ == '__main__':
|
||||
spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||
level_name='rooms', record_episodes=False, doors_have_area=False,
|
||||
level_name='rooms', record_episodes=False, doors_have_area=True,
|
||||
verbose=False,
|
||||
mv_prop=move_props,
|
||||
obs_prop=obs_props
|
||||
)
|
||||
|
||||
# Bundle both environments with global kwargs and parameters
|
||||
env_map = {'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy())),
|
||||
'item': (ItemFactory, dict(item_prop=item_props,
|
||||
**factory_kwargs.copy())),
|
||||
'itemdirt': (DirtItemFactory, dict(dirt_prop=dirt_props,
|
||||
item_prop=item_props,
|
||||
**factory_kwargs.copy()))}
|
||||
env_map = {}
|
||||
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()))})
|
||||
if False:
|
||||
env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_map.update({'itemdirt': (DirtItemFactory, dict(dirt_prop=dirt_props, item_prop=item_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_names = list(env_map.keys())
|
||||
|
||||
# Define parameter versions according with #1,2[1,0,N],3
|
||||
observation_modes = {
|
||||
# Fill-value = 0
|
||||
# DEACTIVATED 'seperate_0': dict(additional_env_kwargs=dict(additional_agent_placeholder=0)),
|
||||
# Fill-value = 1
|
||||
# DEACTIVATED 'seperate_1': dict(additional_env_kwargs=dict(additional_agent_placeholder=1)),
|
||||
# Fill-value = N(0, 1)
|
||||
'seperate_N': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder='N',
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
),
|
||||
'in_lvl_obs': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.LEVEL,
|
||||
omit_agent_self=True,
|
||||
additional_agent_placeholder=None,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
),
|
||||
observation_modes = {}
|
||||
if False:
|
||||
observation_modes.update({
|
||||
'seperate_1': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=1,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_0': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=0,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'seperate_N': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.COMBINED,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
),
|
||||
additional_env_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder='N',
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
'in_lvl_obs': dict(
|
||||
post_training_kwargs=
|
||||
dict(obs_prop=ObservationProperties(
|
||||
render_agents=AgentRenderOptions.LEVEL,
|
||||
omit_agent_self=True,
|
||||
additional_agent_placeholder=None,
|
||||
frames_to_stack=3,
|
||||
pomdp_r=2)
|
||||
)
|
||||
)})
|
||||
observation_modes.update({
|
||||
# No further adjustment needed
|
||||
'no_obs': dict(
|
||||
post_training_kwargs=
|
||||
@ -257,14 +296,14 @@ if __name__ == '__main__':
|
||||
pomdp_r=2)
|
||||
)
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
# Train starts here ############################################################
|
||||
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||
if True:
|
||||
for obs_mode in observation_modes.keys():
|
||||
for env_name in env_names:
|
||||
for model_cls in [h.MODEL_MAP['A2C'], h.MODEL_MAP['DQN']]:
|
||||
for model_cls in [h.MODEL_MAP['A2C']]:
|
||||
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||
identifier = f'{model_cls.__name__}_{start_time}'
|
||||
# Train each combination per seed
|
||||
@ -274,7 +313,7 @@ if __name__ == '__main__':
|
||||
# Retrieve and set the observation mode specific env parameters
|
||||
additional_kwargs = observation_modes.get(obs_mode, {}).get("additional_env_kwargs", {})
|
||||
env_kwargs.update(additional_kwargs)
|
||||
for seed in range(5):
|
||||
for seed in range(n_seeds):
|
||||
env_kwargs.update(env_seed=seed)
|
||||
# Output folder
|
||||
seed_path = combination_path / f'{str(seed)}_{identifier}'
|
||||
@ -352,6 +391,7 @@ if __name__ == '__main__':
|
||||
# Evaluation starts here #####################################################
|
||||
# First Iterate over every model and monitor "as trained"
|
||||
if True:
|
||||
print('Start Baseline Tracking')
|
||||
for obs_mode in observation_modes:
|
||||
obs_mode_path = next(x for x in study_root_path.iterdir() if x.is_dir() and x.name == obs_mode)
|
||||
# For trained policy in study_root_path / identifier
|
||||
@ -370,9 +410,11 @@ if __name__ == '__main__':
|
||||
|
||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
# load_model_run_baseline(seed_path)
|
||||
print('Baseline Tracking done')
|
||||
|
||||
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||
if True:
|
||||
print('Start OOD Tracking')
|
||||
for obs_mode in observation_modes:
|
||||
obs_mode_path = next(x for x in study_root_path.iterdir() if x.is_dir() and x.name == obs_mode)
|
||||
# For trained policy in study_root_path / identifier
|
||||
@ -387,18 +429,19 @@ if __name__ == '__main__':
|
||||
pool = mp.Pool(mp.cpu_count())
|
||||
paths = list(y for y in policy_path.iterdir() if y.is_dir() \
|
||||
and not (y / ood_monitor_file).exists())
|
||||
result = pool.starmap(load_model_run_study,
|
||||
it.product(paths,
|
||||
(env_map[env_path.name][0],),
|
||||
(observation_modes[obs_mode],))
|
||||
)
|
||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
# load_model_run_study(seed_path)
|
||||
# result = pool.starmap(load_model_run_study,
|
||||
# it.product(paths,
|
||||
# (env_map[env_path.name][0],),
|
||||
# (observation_modes[obs_mode],))
|
||||
# )
|
||||
for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||
load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
||||
print('OOD Tracking Done')
|
||||
|
||||
# Plotting
|
||||
if True:
|
||||
# TODO: Plotting
|
||||
|
||||
print('Start Plotting')
|
||||
for observation_folder in (x for x in study_root_path.iterdir() if x.is_dir()):
|
||||
df_list = list()
|
||||
for env_folder in (x for x in observation_folder.iterdir() if x.is_dir()):
|
||||
|
Reference in New Issue
Block a user