major redesign ob observations and entittes

This commit is contained in:
Steffen Illium
2023-06-09 14:04:17 +02:00
parent 901fbcbc32
commit c552c35f66
161 changed files with 4458 additions and 4163 deletions

View File

@ -47,7 +47,7 @@ def encapsule_env_factory(env_fctry, env_kwrgs):
def load_model_run_baseline(policy_path, env_to_run):
# retrieve model class
model_cls = h.MODEL_MAP['A2C']
# Load both agents
# Load both agent
model = model_cls.load(policy_path / 'model.zip', device='cpu')
# Load old env kwargs
with next(policy_path.glob('*params.json')).open('r') as f:
@ -76,7 +76,7 @@ def load_model_run_baseline(policy_path, env_to_run):
def load_model_run_combined(root_path, env_to_run, env_kwargs):
# retrieve model class
model_cls = h.MODEL_MAP['A2C']
# Load both agents
# Load both agent
models = [model_cls.load(model_zip, device='cpu') for model_zip in root_path.rglob('model.zip')]
# Load old env kwargs
env_kwargs = env_kwargs.copy()
@ -252,7 +252,7 @@ if __name__ == '__main__':
if individual_run:
print('Start Individual Recording')
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
# For trained policy in study_root_path / identifier
# For trained policy in study_root_path / _identifier
policy_path = study_root_path / env_key
load_model_run_baseline(policy_path, env_map[policy_path.name][0])
@ -264,7 +264,7 @@ if __name__ == '__main__':
if combined_run:
print('Start combined run')
for env_key in (env_key for env_key in env_map if 'combined' == env_key):
# For trained policy in study_root_path / identifier
# For trained policy in study_root_path / _identifier
factory, kwargs = env_map[env_key]
load_model_run_combined(study_root_path, factory, kwargs)
print('OOD Tracking Done')