mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
Debugging
This commit is contained in:
@ -114,6 +114,7 @@ if __name__ == '__main__':
|
||||
train = True
|
||||
individual_run = True
|
||||
combined_run = True
|
||||
multi_env = False
|
||||
|
||||
train_steps = 2e5
|
||||
frames_to_stack = 3
|
||||
@ -122,7 +123,7 @@ if __name__ == '__main__':
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}'
|
||||
|
||||
def policy_model_kwargs():
|
||||
return dict(learning_rate=0.0003, n_steps=10, gamma=0.95, gae_lambda=0.0, ent_coef=0.01, vf_coef=0.5)
|
||||
return dict()
|
||||
|
||||
# Define Global Env Parameters
|
||||
# Define properties object parameters
|
||||
@ -142,22 +143,22 @@ if __name__ == '__main__':
|
||||
item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||
level_name='rooms', doors_have_area=False,
|
||||
factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=True,
|
||||
level_name='rooms', doors_have_area=True,
|
||||
verbose=False,
|
||||
mv_prop=move_props,
|
||||
obs_prop=obs_props,
|
||||
done_at_collision=True
|
||||
done_at_collision=False
|
||||
)
|
||||
|
||||
# Bundle both environments with global kwargs and parameters
|
||||
env_map = {}
|
||||
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||
**factory_kwargs.copy()))})
|
||||
# env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||
# **factory_kwargs.copy()))})
|
||||
# env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||
# **factory_kwargs.copy()))})
|
||||
env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
|
||||
item_prop=item_props,
|
||||
dirt_prop=dirt_props,
|
||||
@ -168,7 +169,7 @@ if __name__ == '__main__':
|
||||
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||
if train:
|
||||
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||
model_cls = h.MODEL_MAP['A2C']
|
||||
model_cls = h.MODEL_MAP['PPO']
|
||||
combination_path = study_root_path / env_key
|
||||
env_class, env_kwargs = env_map[env_key]
|
||||
|
||||
@ -177,8 +178,11 @@ if __name__ == '__main__':
|
||||
continue
|
||||
combination_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||
for _ in range(6)], start_method="spawn")
|
||||
if not multi_env:
|
||||
env_factory = encapsule_env_factory(env_class, env_kwargs)()
|
||||
else:
|
||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||
for _ in range(6)], start_method="spawn")
|
||||
|
||||
param_path = combination_path / f'env_params.json'
|
||||
try:
|
||||
|
Reference in New Issue
Block a user