study e_1 corpus

This commit is contained in:
Steffen Illium
2021-09-08 16:24:14 +02:00
parent b09055d95d
commit 4c21a0af7c
8 changed files with 246 additions and 87 deletions

View File

@@ -3,7 +3,6 @@ from pathlib import Path
import yaml
from natsort import natsorted
from stable_baselines3 import PPO, DQN, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from environments.factory.factory_dirt import DirtProperties, DirtFactory
@@ -12,13 +11,12 @@ from environments.factory.factory_item import ItemProperties, ItemFactory
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
if __name__ == '__main__':
model_name = 'PPO_1631029150'
model_name = 'DQN_1631092016'
run_id = 0
seed=69
seed = 69
out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name
@@ -38,5 +36,5 @@ if __name__ == '__main__':
this_model = model_files[0]
model_cls = next(val for key, val in model_map.items() if key in model_name)
model = model_cls.load(this_model)
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
print(evaluation_result)