mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
# foreign imports
|
|
import warnings
|
|
|
|
from pathlib import Path
|
|
import yaml
|
|
from gym.wrappers import FrameStack
|
|
from natsort import natsorted
|
|
|
|
from stable_baselines3.common.callbacks import CallbackList
|
|
from stable_baselines3 import PPO, DQN, A2C
|
|
|
|
# our imports
|
|
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
|
from environments.logging.monitor import MonitorCallback
|
|
from algorithms.reg_dqn import RegDQN
|
|
from main import compare_model_runs, compare_seed_runs
|
|
|
|
# Silence FutureWarning and UserWarning output for the whole process.
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)

# Lookup table from an algorithm's name to the class used to (re)load it.
model_mapping = {
    'A2C': A2C,
    'PPO': PPO,
    'DQN': DQN,
    'RegDQN': RegDQN,
}
|
|
|
|
|
|
if __name__ == '__main__':
    # Reload a stored model checkpoint and roll it out in a DirtFactory
    # environment, printing the accumulated reward of every episode.
    #
    # Original plan, kept for reference:
    #   get n policies pi_1, ..., pi_n trained in single agent setting
    #   rewards = []
    #   repeat for x eval runs
    #       total reward = rollout game for y steps with n policies in multi-agent setting
    #       rewards += [total reward]
    #   boxplot total rewards

    run_id = '1623923982'
    model_name = 'A2C'

    # -----------------------
    out_path = Path(__file__).parent / 'debug_out'

    # from sb3_contrib import QRDQN
    model_path = out_path / f'{model_name}_{run_id}'
    model_files = list(natsorted(model_path.rglob('model_*.zip')))
    if not model_files:
        # Fail with a message that names the directory instead of a bare IndexError.
        raise FileNotFoundError(f'No "model_*.zip" checkpoint found below {model_path}')
    # Use the first checkpoint in natural sort order.
    this_model = model_files[0]
    render = True

    model = model_mapping[model_name].load(this_model)

    # NOTE(review): `seed` is not handed to the environment or the model below,
    # so the three iterations are configured identically - confirm whether
    # seeding was intended.
    for seed in range(3):
        with (model_path / f'env_{model_path.name}.yaml').open('r') as f:
            # NOTE(review): FullLoader can build more than plain data types; only
            # load environment files from trusted training runs.
            env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
        # Only referenced by the commented-out override below.
        dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
                                    max_local_amount=3, spawn_frequency=1, max_spawn_ratio=0.05)
        # env_kwargs.update(n_agents=1, dirt_properties=dirt_props)
        env = DirtFactory(**env_kwargs)

        env = FrameStack(env, 4)

        exp_out_path = model_path / 'exp'
        # NOTE(review): the callback list is built but never passed on in this script.
        callbacks = CallbackList(
            [MonitorCallback(filepath=exp_out_path / 'future_exp_name')]
        )

        n_actions = env.action_space.n

        for epoch in range(100):
            observations = env.reset()
            if render:
                if isinstance(env, FrameStack):
                    env.env.render()
                else:
                    env.render()
            done_bool = False
            # Sum of all step rewards of this episode.
            total_reward = 0
            while not done_bool:
                if env.n_agents > 1:
                    # One prediction per agent observation, all from the same model.
                    actions = [model.predict(obs, deterministic=False)[0] for obs in observations]
                else:
                    actions = model.predict(observations, deterministic=False)[0]

                observations, step_reward, done_bool, info_obj = env.step(actions)
                # Accumulate instead of overwriting, so the report below covers
                # the whole episode rather than only its final step.
                total_reward += step_reward
                if render:
                    env.render()
            print(f'Factory run {epoch} done, reward is:\n {total_reward}')

    if out_path:
        compare_seed_runs(out_path.parent)