mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-12 10:30:37 +01:00
study e_1 corpus
This commit is contained in:
@@ -3,7 +3,6 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from natsort import natsorted
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
||||
@@ -12,13 +11,12 @@ from environments.factory.factory_item import ItemProperties, ItemFactory
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_name = 'PPO_1631029150'
|
||||
model_name = 'DQN_1631092016'
|
||||
run_id = 0
|
||||
seed=69
|
||||
seed = 69
|
||||
out_path = Path(__file__).parent / 'debug_out'
|
||||
model_path = out_path / model_name
|
||||
|
||||
@@ -38,5 +36,5 @@ if __name__ == '__main__':
|
||||
this_model = model_files[0]
|
||||
model_cls = next(val for key, val in model_map.items() if key in model_name)
|
||||
model = model_cls.load(this_model)
|
||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
|
||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
|
||||
print(evaluation_result)
|
||||
|
||||
Reference in New Issue
Block a user