Updated RL_runner

This commit is contained in:
Julian Schönberger
2024-08-09 16:33:06 +02:00
parent 4c81e4b865
commit 50ef0c94e9

View File

@ -11,7 +11,7 @@ def rerun_dirt_quadrant_agent1_training():
eval_cfg = load_yaml_file(eval_cfg_path)
print("Training phase")
agent = A2C(train_cfg, eval_cfg)
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
agent.train_loop()
print("Evaluation phase")
agent.eval_loop(n_episodes=1)
@ -23,11 +23,11 @@ def two_rooms_training(max_steps, agent_name):
train_cfg = load_yaml_file(train_cfg_path)
eval_cfg = load_yaml_file(eval_cfg_path)
train_cfg["algorithm"]["max_steps"] = max_steps
# train_cfg["algorithm"]["max_steps"] = max_steps
train_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_train_config"
eval_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_eval_config"
print("Training phase")
agent = A2C(train_cfg, eval_cfg)
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
agent.train_loop()
print("Evaluation phase")
agent.eval_loop(n_episodes=1)
@ -43,11 +43,11 @@ def rerun_two_rooms_agent2_training():
####### Eval routines ########
def single_agent_eval(config_name, run_folder_name):
eval_cfg_path = Path(f'../marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
train_cfg = eval_cfg = load_yaml_file(eval_cfg_path)
eval_cfg_path = Path(f'./marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
eval_cfg = load_yaml_file(eval_cfg_path)
# A value for train_cfg is required, but the train environment won't be used
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg)
agent = A2C(eval_cfg=eval_cfg, mode="eval")
print("Evaluation phase")
agent.load_agents([run_folder_name])
agent.eval_loop(1)
@ -59,7 +59,7 @@ def multi_agent_eval(config_name, runs, emergent_phenomenon=False):
eval_cfg = load_yaml_file(eval_cfg_path)
# A value for train_cfg is required, but the train environment won't be used
agent = A2C(train_cfg=eval_cfg, eval_cfg=eval_cfg)
agent = A2C(eval_cfg=eval_cfg, mode="eval")
print("Evaluation phase")
agent.load_agents(runs)
agent.eval_loop(1)