mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-08 02:21:36 +02:00
Updated RL_runner
This commit is contained in:
@ -11,7 +11,7 @@ def rerun_dirt_quadrant_agent1_training():
|
||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||
|
||||
print("Training phase")
|
||||
agent = A2C(train_cfg, eval_cfg)
|
||||
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
|
||||
agent.train_loop()
|
||||
print("Evaluation phase")
|
||||
agent.eval_loop(n_episodes=1)
|
||||
@ -23,11 +23,11 @@ def two_rooms_training(max_steps, agent_name):
|
||||
train_cfg = load_yaml_file(train_cfg_path)
|
||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||
|
||||
train_cfg["algorithm"]["max_steps"] = max_steps
|
||||
# train_cfg["algorithm"]["max_steps"] = max_steps
|
||||
train_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_train_config"
|
||||
eval_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_eval_config"
|
||||
print("Training phase")
|
||||
agent = A2C(train_cfg, eval_cfg)
|
||||
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
|
||||
agent.train_loop()
|
||||
print("Evaluation phase")
|
||||
agent.eval_loop(n_episodes=1)
|
||||
@ -43,11 +43,11 @@ def rerun_two_rooms_agent2_training():
|
||||
|
||||
####### Eval routines ########
|
||||
def single_agent_eval(config_name, run_folder_name):
|
||||
eval_cfg_path = Path(f'../marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
|
||||
train_cfg = eval_cfg = load_yaml_file(eval_cfg_path)
|
||||
eval_cfg_path = Path(f'./marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
|
||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||
|
||||
# A value for train_cfg is required, but the train environment won't be used
|
||||
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg)
|
||||
agent = A2C(eval_cfg=eval_cfg, mode="eval")
|
||||
print("Evaluation phase")
|
||||
agent.load_agents([run_folder_name])
|
||||
agent.eval_loop(1)
|
||||
@ -59,7 +59,7 @@ def multi_agent_eval(config_name, runs, emergent_phenomenon=False):
|
||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||
|
||||
# A value for train_cfg is required, but the train environment won't be used
|
||||
agent = A2C(train_cfg=eval_cfg, eval_cfg=eval_cfg)
|
||||
agent = A2C(eval_cfg=eval_cfg, mode="eval")
|
||||
print("Evaluation phase")
|
||||
agent.load_agents(runs)
|
||||
agent.eval_loop(1)
|
||||
|
Reference in New Issue
Block a user