mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-14 03:00:37 +01:00
Updated RL_runner
This commit is contained in:
@@ -11,7 +11,7 @@ def rerun_dirt_quadrant_agent1_training():
|
|||||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||||
|
|
||||||
print("Training phase")
|
print("Training phase")
|
||||||
agent = A2C(train_cfg, eval_cfg)
|
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
|
||||||
agent.train_loop()
|
agent.train_loop()
|
||||||
print("Evaluation phase")
|
print("Evaluation phase")
|
||||||
agent.eval_loop(n_episodes=1)
|
agent.eval_loop(n_episodes=1)
|
||||||
@@ -23,11 +23,11 @@ def two_rooms_training(max_steps, agent_name):
|
|||||||
train_cfg = load_yaml_file(train_cfg_path)
|
train_cfg = load_yaml_file(train_cfg_path)
|
||||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||||
|
|
||||||
train_cfg["algorithm"]["max_steps"] = max_steps
|
# train_cfg["algorithm"]["max_steps"] = max_steps
|
||||||
train_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_train_config"
|
train_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_train_config"
|
||||||
eval_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_eval_config"
|
eval_cfg["env"]["env_name"] = f"rl/two_rooms_{agent_name}_eval_config"
|
||||||
print("Training phase")
|
print("Training phase")
|
||||||
agent = A2C(train_cfg, eval_cfg)
|
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode="train")
|
||||||
agent.train_loop()
|
agent.train_loop()
|
||||||
print("Evaluation phase")
|
print("Evaluation phase")
|
||||||
agent.eval_loop(n_episodes=1)
|
agent.eval_loop(n_episodes=1)
|
||||||
@@ -43,11 +43,11 @@ def rerun_two_rooms_agent2_training():
|
|||||||
|
|
||||||
####### Eval routines ########
|
####### Eval routines ########
|
||||||
def single_agent_eval(config_name, run_folder_name):
|
def single_agent_eval(config_name, run_folder_name):
|
||||||
eval_cfg_path = Path(f'../marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
|
eval_cfg_path = Path(f'./marl_factory_grid/algorithms/rl/single_agent_configs/{config_name}_eval_config.yaml')
|
||||||
train_cfg = eval_cfg = load_yaml_file(eval_cfg_path)
|
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||||
|
|
||||||
# A value for train_cfg is required, but the train environment won't be used
|
# A value for train_cfg is required, but the train environment won't be used
|
||||||
agent = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg)
|
agent = A2C(eval_cfg=eval_cfg, mode="eval")
|
||||||
print("Evaluation phase")
|
print("Evaluation phase")
|
||||||
agent.load_agents([run_folder_name])
|
agent.load_agents([run_folder_name])
|
||||||
agent.eval_loop(1)
|
agent.eval_loop(1)
|
||||||
@@ -59,7 +59,7 @@ def multi_agent_eval(config_name, runs, emergent_phenomenon=False):
|
|||||||
eval_cfg = load_yaml_file(eval_cfg_path)
|
eval_cfg = load_yaml_file(eval_cfg_path)
|
||||||
|
|
||||||
# A value for train_cfg is required, but the train environment won't be used
|
# A value for train_cfg is required, but the train environment won't be used
|
||||||
agent = A2C(train_cfg=eval_cfg, eval_cfg=eval_cfg)
|
agent = A2C(eval_cfg=eval_cfg, mode="eval")
|
||||||
print("Evaluation phase")
|
print("Evaluation phase")
|
||||||
agent.load_agents(runs)
|
agent.load_agents(runs)
|
||||||
agent.eval_loop(1)
|
agent.eval_loop(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user