mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-12 10:30:37 +01:00
Debugging
This commit is contained in:
@@ -1,14 +1,10 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3 import A2C, PPO, DQN
|
||||
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
from environments.factory.factory_dirt import DirtFactory
|
||||
from environments.factory.combined_factories import DirtItemFactory
|
||||
from environments.logging.recorder import EnvRecorder
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
@@ -17,7 +13,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
determin = True
|
||||
determin = False
|
||||
render = True
|
||||
record = False
|
||||
seed = 67
|
||||
@@ -37,7 +33,7 @@ if __name__ == '__main__':
|
||||
|
||||
this_model = out_path / 'model.zip'
|
||||
|
||||
model_cls = A2C # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||
model_cls = PPO # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||
models = [model_cls.load(this_model)]
|
||||
|
||||
# Init Env
|
||||
|
||||
Reference in New Issue
Block a user