Rework of observations and entity differentiation; observations are now built lazily, triggered by notification

This commit is contained in:
Steffen Illium
2021-12-22 10:48:36 +01:00
parent 7f7a3d9a3b
commit b43f595207
14 changed files with 961 additions and 487 deletions

View File

@@ -3,6 +3,7 @@ from pathlib import Path
import numpy as np
import yaml
from stable_baselines3 import A2C
from environments import helpers as h
from environments.helpers import Constants as c
@@ -16,13 +17,12 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__':
determin = False
determin = True
render = True
record = True
record = False
seed = 67
n_agents = 1
out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
out_path_2 = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/1_A2C_obs_stack_3_gae_0.25_n_steps_16')
out_path = Path('study_out/single_run_with_export/dirt')
model_path = out_path
with (out_path / f'env_params.json').open('r') as f:
@@ -35,10 +35,9 @@ if __name__ == '__main__':
env_kwargs.update(record_episodes=record, done_at_collision=True)
this_model = out_path / 'model.zip'
other_model = out_path / 'model.zip'
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
models = [model_cls.load(this_model)] # , model_cls.load(other_model)]
model_cls = A2C # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
models = [model_cls.load(this_model)]
# Init Env
with DirtFactory(**env_kwargs) as env:
@@ -59,6 +58,8 @@ if __name__ == '__main__':
rew += step_r
if render:
env.render()
if not env.unwrapped.unwrapped[c.AGENT][0].temp_valid:
print('Invalid ACtions')
if done_bool:
break
print(f'Factory run {episode} done, reward is:\n {rew}')