mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
major redesign ob observations and entittes
This commit is contained in:
@ -47,7 +47,7 @@ def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||
def load_model_run_baseline(policy_path, env_to_run):
|
||||
# retrieve model class
|
||||
model_cls = h.MODEL_MAP['A2C']
|
||||
# Load both agents
|
||||
# Load both agent
|
||||
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||
# Load old env kwargs
|
||||
with next(policy_path.glob('*params.json')).open('r') as f:
|
||||
@ -76,7 +76,7 @@ def load_model_run_baseline(policy_path, env_to_run):
|
||||
def load_model_run_combined(root_path, env_to_run, env_kwargs):
|
||||
# retrieve model class
|
||||
model_cls = h.MODEL_MAP['A2C']
|
||||
# Load both agents
|
||||
# Load both agent
|
||||
models = [model_cls.load(model_zip, device='cpu') for model_zip in root_path.rglob('model.zip')]
|
||||
# Load old env kwargs
|
||||
env_kwargs = env_kwargs.copy()
|
||||
@ -252,7 +252,7 @@ if __name__ == '__main__':
|
||||
if individual_run:
|
||||
print('Start Individual Recording')
|
||||
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||
# For trained policy in study_root_path / identifier
|
||||
# For trained policy in study_root_path / _identifier
|
||||
policy_path = study_root_path / env_key
|
||||
load_model_run_baseline(policy_path, env_map[policy_path.name][0])
|
||||
|
||||
@ -264,7 +264,7 @@ if __name__ == '__main__':
|
||||
if combined_run:
|
||||
print('Start combined run')
|
||||
for env_key in (env_key for env_key in env_map if 'combined' == env_key):
|
||||
# For trained policy in study_root_path / identifier
|
||||
# For trained policy in study_root_path / _identifier
|
||||
factory, kwargs = env_map[env_key]
|
||||
load_model_run_combined(study_root_path, factory, kwargs)
|
||||
print('OOD Tracking Done')
|
||||
|
Reference in New Issue
Block a user