Adjustments and Documentation, recording and new environments, refactoring

Steffen Illium
2022-08-04 14:57:48 +02:00
parent e7461d7dcf
commit 6a24e7b518
41 changed files with 1660 additions and 760 deletions


@@ -1,11 +1,12 @@
import sys
import time
from pathlib import Path
from matplotlib import pyplot as plt
import itertools as it
+import simplejson
import stable_baselines3 as sb3
# This is needed when you put this file in a subfolder.
try:
# noinspection PyUnboundLocalVariable
if __package__ is None:
@@ -18,19 +19,14 @@ except NameError:
DIR = None
pass
-import simplejson
-from stable_baselines3.common.vec_env import SubprocVecEnv
-from environments import helpers as h
-from environments.factory.factory_dirt import DirtProperties, DirtFactory
from environments.logging.envmonitor import EnvMonitor
+from environments.logging.recorder import EnvRecorder
+from environments.factory.additional.dirt.dirt_util import DirtProperties
+from environments.factory.additional.dirt.factory_dirt import DirtFactory
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
-import pickle
-from plotting.compare_runs import compare_seed_runs, compare_model_runs
-import pandas as pd
-import seaborn as sns
-import multiprocessing as mp
+from plotting.compare_runs import compare_seed_runs
"""
Welcome to this quick start file. Here we will see how to:
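Aside: the try/except guard in the hunk above is only partially visible in the diff. A minimal sketch of the full pattern it implements follows; the path manipulation is an assumption based on the visible fragments, not the file's exact code.

    import sys
    from pathlib import Path

    # Sketch of the subfolder import guard: when the script is run directly
    # (no package context), put the project root on sys.path so that
    # 'environments' and 'plotting' resolve as top-level packages.
    try:
        # noinspection PyUnboundLocalVariable
        if __package__ is None:
            DIR = Path(__file__).resolve().parent
            sys.path.insert(0, str(DIR.parent))
        else:
            DIR = None
    except NameError:
        DIR = None
        pass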
@@ -53,6 +49,8 @@ if __name__ == '__main__':
model_class = sb3.PPO
env_class = DirtFactory
+env_params_json = 'env_params.json'
# Define a global study save path
start_time = int(time.time())
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
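For orientation, the path setup above yields one timestamped study folder per run, with one subfolder per seed. The layout below is a sketch inferred from the paths in this file; the seed folder naming is an assumption.

    # study_out/
    #     <script_stem>_<start_time>/    <- study_root_path
    #         <seed>/                    <- seed_path (assumed naming)
    #             env_params.json        <- env parameters (param_path)
    #             monitor.pick           <- EnvMonitor measures (monitor_path)
    #             recorder.json          <- EnvRecorder states (recorder_path)
    #             model.zip              <- trained model (model_save_path)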
@@ -100,7 +98,7 @@ if __name__ == '__main__':
mv_prop=move_props, # See Above
obs_prop=obs_props, # See Above
done_at_collision=True,
-dirt_props=dirt_props
+dirt_prop=dirt_props
)
#########################################################
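The rename above (dirt_props -> dirt_prop) must match the DirtFactory constructor's keyword. A minimal sketch of how the surrounding kwargs are assembled, assuming the property classes can be instantiated with defaults (the file configures them explicitly):

    # Sketch only: default property objects stand in for the file's actual settings.
    obs_props = ObservationProperties()
    move_props = MovementProperties()
    dirt_props = DirtProperties()
    env_kwargs = dict(
        mv_prop=move_props,
        obs_prop=obs_props,
        done_at_collision=True,
        dirt_prop=dirt_props,  # keyword renamed from 'dirt_props' in this commit
    )
    with DirtFactory(**env_kwargs) as env:
        state = env.reset()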
@@ -120,30 +118,37 @@ if __name__ == '__main__':
seed_path.mkdir(parents=True, exist_ok=True)
# Parameter Storage
-param_path = seed_path / f'env_params.json'
+param_path = seed_path / env_params_json
# Observation (measures) Storage
monitor_path = seed_path / 'monitor.pick'
+recorder_path = seed_path / 'recorder.json'
# Save path for the trained model
model_save_path = seed_path / f'model.zip'
# Env Init & Model kwargs definition
-with DirtFactory(env_kwargs) as env_factory:
+with env_class(**env_kwargs) as env_factory:
# EnvMonitor Init
env_monitor_callback = EnvMonitor(env_factory)
+# EnvRecorder Init
+env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
# Model Init
model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')
# Train the model
-model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback])
+model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
#########################################################
# 3. Save env and agent for later analysis.
# Save the trained model, the monitor (env measures), the recorder (env states) and the env parameters
+model.named_observation_space = env_factory.named_observation_space
+model.named_action_space = env_factory.named_action_space
model.save(model_save_path)
env_factory.save_params(param_path)
env_monitor_callback.save_run(monitor_path)
+env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
# Compare performance runs for each seed within a model
try:
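Both EnvMonitor and EnvRecorder are handed to model.learn as callbacks, so they presumably follow stable-baselines3's callback interface. A minimal sketch of that pattern, with PeriodicRecorder as a hypothetical stand-in rather than the repo's EnvRecorder:

    from stable_baselines3.common.callbacks import BaseCallback

    class PeriodicRecorder(BaseCallback):
        # Hypothetical stand-in illustrating the callback pattern.
        def __init__(self, env, freq: int, verbose: int = 0):
            super().__init__(verbose)
            self.env = env
            self.freq = max(1, freq)
            self.records = []

        def _on_step(self) -> bool:
            # BaseCallback maintains self.n_calls; sample every `freq` steps.
            if self.n_calls % self.freq == 0:
                self.records.append(self.locals.get('infos'))
            return True  # returning False would stop training early

With freq=int(train_steps / 400 / 10) the recorder samples every train_steps/4000 steps, i.e. roughly 4000 snapshots per run; attaching named_observation_space and named_action_space to the model before saving keeps the space layout available when the model is reloaded for evaluation.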
@@ -164,18 +169,19 @@ if __name__ == '__main__':
# Load the trained agent
model = model_cls.load(policy_path / 'model.zip', device='cpu')
# Load old env kwargs
-with next(policy_path.glob('*.json')).open('r') as f:
+with next(policy_path.glob(env_params_json)).open('r') as f:
env_kwargs = simplejson.load(f)
# Make the env stop at collisions
# (you only want a single collision per episode, hence the statistics)
env_kwargs.update(done_at_collision=True)
# Init Env
-with env_to_run(**env_kwargs) as env_factory:
+with env_class(**env_kwargs) as env_factory:
monitored_env_factory = EnvMonitor(env_factory)
# Evaluation loop over n episodes (here: 100)
for episode in range(100):
# noinspection PyRedeclaration
env_state = monitored_env_factory.reset()
rew, done_bool = 0, False
while not done_bool:
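The body of the while loop is unchanged by this commit and therefore elided from the diff. For orientation, a typical evaluation step with the gym 4-tuple step API that stable-baselines3 used at the time looks roughly like this (a sketch, not the file's exact code):

    # One evaluation step (sketch): act greedily, step the env, accumulate reward.
    action, _ = model.predict(env_state, deterministic=True)
    env_state, step_reward, done_bool, info = monitored_env_factory.step(action)
    rew += step_reward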
@@ -185,8 +191,5 @@ if __name__ == '__main__':
if done_bool:
break
print(f'Factory run {episode} done, reward is:\n {rew}')
-monitored_env_factory.save_run(filepath=policy_path / f'{baseline_monitor_file}.pick')
-# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
-# load_model_run_baseline(policy_path)
+monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
print('Measurements Done')
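Afterwards, the saved eval_run_monitor.pick can be inspected offline. A rough sketch, assuming the pickle holds a pandas-compatible table of per-episode measures; the repo's own plotting helpers (e.g. compare_seed_runs) are the supported route:

    import pickle
    from pathlib import Path

    # Assumed path layout; adjust to the actual policy_path used above.
    monitor_file = Path('study_out') / 'eval_run_monitor.pick'
    with monitor_file.open('rb') as f:
        monitor_data = pickle.load(f)
    print(type(monitor_data))  # inspect the structure before assuming a schema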