Reworked differentiation between train and eval execution + Renamed cfgs + Added algorithm seeding + Included early stopping functionality + Added Weights & Biases logging

Julian Schönberger
2024-08-09 16:30:04 +02:00
parent 81b12612ed
commit 8e8e925278
3 changed files with 190 additions and 59 deletions

View File

@@ -1,7 +1,9 @@
 import os
+import pickle
 import torch
 from typing import Union, List
 import numpy as np
+import wandb
 from tqdm import tqdm
 from marl_factory_grid.algorithms.rl.base_a2c import PolicyGradient, cumulate_discount
@@ -9,34 +11,52 @@ from marl_factory_grid.algorithms.rl.constants import Names
 from marl_factory_grid.algorithms.rl.utils import transform_observations, _as_torch, is_door_close, \
     get_dirt_piles_positions, update_target_pile, update_ordered_dirt_piles, get_all_cleaned_dirt_piles, \
     distribute_indices, set_agents_spawnpoints, get_ordered_dirt_piles, handle_finished_episode, save_configs, \
-    save_agent_models, get_all_observations, get_agents_positions
+    save_agent_models, get_all_observations, get_agents_positions, has_low_change_phase_started, significant_deviation, \
+    log_wandb_training
 from marl_factory_grid.algorithms.utils import add_env_props
-from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps, plot_reward_development, \
-    create_info_maps
+from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps, plot_return_development, \
+    create_info_maps, plot_return_development_change

 nms = Names
 ListOrTensor = Union[List, torch.Tensor]


 class A2C:
-    def __init__(self, train_cfg, eval_cfg):
-        self.factory = add_env_props(train_cfg)
+    def __init__(self, train_cfg=None, eval_cfg=None, mode="train"):
+        self.mode = mode
+        if mode == nms.TRAIN:
+            self.train_factory = add_env_props(train_cfg)
+            self.train_cfg = train_cfg
+            self.n_agents = train_cfg[nms.ENV][nms.N_AGENTS]
+        else:
+            self.n_agents = eval_cfg[nms.ENV][nms.N_AGENTS]
         self.eval_factory = add_env_props(eval_cfg)
-        self.__training = True
-        self.train_cfg = train_cfg
         self.eval_cfg = eval_cfg
-        self.cfg = train_cfg
-        self.n_agents = train_cfg[nms.ENV][nms.N_AGENTS]
         self.setup()
-        self.reward_development = []
         self.action_probabilities = {agent_idx: [] for agent_idx in range(self.n_agents)}

     def setup(self):
         """ Initialize agents and create entry for run results according to configuration """
+        if self.mode == "train":
+            self.cfg = self.train_cfg
+            self.factory = self.train_factory
+            self.gamma = self.cfg[nms.ALGORITHM][nms.GAMMA]
+        else:
+            self.cfg = self.eval_cfg
+            self.factory = self.eval_factory
+            self.gamma = 0.99
+        seed = self.cfg[nms.ALGORITHM][nms.SEED]
+        print("Algorithm Seed: ", seed)
+        if seed == -1:
+            seed = np.random.choice(range(1000))
+            print("Algorithm seed is -1. Pick random seed: ", seed)
         self.obs_dim = 2 + 2 * len(get_dirt_piles_positions(self.factory)) if self.cfg[nms.ALGORITHM][
             nms.PILE_OBSERVABILITY] == nms.ALL else 4
         self.act_dim = 4  # The 4 movement directions
-        self.agents = [PolicyGradient(self.factory, agent_id=i, obs_dim=self.obs_dim, act_dim=self.act_dim) for i in
+        self.agents = [PolicyGradient(self.factory, seed=seed, gamma=self.gamma, agent_id=i, obs_dim=self.obs_dim, act_dim=self.act_dim) for i in
                        range(self.n_agents)]
         if self.cfg[nms.ENV][nms.SAVE_AND_LOG]:
@@ -48,13 +68,9 @@ class A2C:
             os.mkdir(self.results_path)
             # Save settings in results folder
             save_configs(self.results_path, self.cfg, self.factory.conf, self.eval_factory.conf)
+            if self.cfg[nms.ENV][nms.WANDB_LOG]:
+                wandb.login()
+                wandb.init(project="EDYS", name=str(next_run_number))

-    def set_cfg(self, eval=False):
-        """ Set the mode of the current configuration """
-        if eval:
-            self.cfg = self.eval_cfg
-        else:
-            self.cfg = self.train_cfg
-
     def load_agents(self, runs_list):
         """ Initialize networks with parameters of already trained agents """
@@ -67,39 +83,65 @@ class A2C:
     def train_loop(self):
         """ Function for training agents """
         env = self.factory
-        n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
+        n_steps, max_steps = [self.train_cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
         global_steps, episode = 0, 0
-        indices = distribute_indices(env, self.cfg, self.n_agents)
+        indices = distribute_indices(env, self.train_cfg, self.n_agents)
         dirt_piles_positions = get_dirt_piles_positions(env)
         target_pile = [partition[0] for partition in
                        indices]  # list of pointers that point to the current target pile for each agent
         cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]
+        low_change_phase_start_episode = -1
+        episode_rewards_development = []
+        return_change_development = []

         pbar = tqdm(total=max_steps)
-        while global_steps < max_steps:
+        loop_condition = True if self.train_cfg[nms.ALGORITHM][nms.EARLY_STOPPING] else global_steps < max_steps
+        while loop_condition:
             _ = env.reset()
-            if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
+            if self.train_cfg[nms.ENV][nms.TRAIN_RENDER]:
                 env.render()
             set_agents_spawnpoints(env, self.n_agents)
-            ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, self.cfg, self.n_agents)
+            ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, self.train_cfg, self.n_agents)
             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
-            if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.ALL:
+            if self.train_cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.ALL:
                 target_pile = [partition[0] for partition in indices]
                 cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]
+            episode_rewards_development.append([])

             # Supply each agent with its local observation
-            obs = transform_observations(env, ordered_dirt_piles, target_pile, self.cfg, self.n_agents)
-            done, rew_log = [False] * self.n_agents, 0
+            obs = transform_observations(env, ordered_dirt_piles, target_pile, self.train_cfg, self.n_agents)
+            done, ep_return = [False] * self.n_agents, 0
+
+            if self.train_cfg[nms.ALGORITHM][nms.EARLY_STOPPING]:
+                if len(return_change_development) > self.train_cfg[nms.ALGORITHM][
+                    nms.LAST_N_EPISODES] and low_change_phase_start_episode == -1 and has_low_change_phase_started(
+                    return_change_development, self.train_cfg[nms.ALGORITHM][nms.LAST_N_EPISODES],
+                    self.train_cfg[nms.ALGORITHM][nms.MEAN_TARGET_CHANGE]):
+                    low_change_phase_start_episode = len(return_change_development)
+                    print(low_change_phase_start_episode)
+
+                # Check if requirements for early stopping are met
+                if len(return_change_development) % 50 == 0:
+                    print(len(return_change_development))
+                if low_change_phase_start_episode != -1 and significant_deviation(return_change_development, low_change_phase_start_episode):
+                    print(f"Early Stopping in Episode: {global_steps} because of significant deviation.")
+                    break
+                if low_change_phase_start_episode != -1 and (len(return_change_development) - low_change_phase_start_episode) >= 1000:
+                    print(f"Early Stopping in Episode: {global_steps} because of episode time limit")
+                    break
+                if low_change_phase_start_episode != -1 and global_steps >= max_steps:
+                    print(f"Early Stopping in Episode: {global_steps} because of global steps time limit")
+                    break

             while not all(done):
                 action = self.use_door_or_move(env, obs, cleaned_dirt_piles) \
                     if nms.DOORS in env.state.entities.keys() else self.get_actions(obs)
                 _, next_obs, reward, done, info = env.step(action)
-                next_obs = transform_observations(env, ordered_dirt_piles, target_pile, self.cfg, self.n_agents)
+                next_obs = transform_observations(env, ordered_dirt_piles, target_pile, self.train_cfg, self.n_agents)

                 # Handle case where agent is on field with dirt
                 reward, done = self.handle_dirt(env, cleaned_dirt_piles, ordered_dirt_piles, target_pile, indices,
-                                                reward, done)
+                                                reward, done, self.train_cfg)

                 if n_steps != 0 and (global_steps + 1) % n_steps == 0: done = True
@@ -110,26 +152,42 @@ class A2C:
                     agent._episode[-1] = (next_obs[ag_i], action[ag_i], reward[ag_i], agent._episode[-1][-1])

                 # Visualize state update
-                if self.cfg[nms.ENV][nms.TRAIN_RENDER]: env.render()
+                if self.train_cfg[nms.ENV][nms.TRAIN_RENDER]: env.render()

                 obs = next_obs

-                if all(done): handle_finished_episode(obs, self.agents, self.cfg)
+                if all(done): handle_finished_episode(obs, self.agents, self.train_cfg)

                 global_steps += 1
-                rew_log += sum(reward)
+                episode_rewards_development[-1].extend(reward)

                 if global_steps >= max_steps: break

-            self.reward_development.append(rew_log)
+            return_change_development.append(
+                sum(episode_rewards_development[-1]) - sum(episode_rewards_development[-2])
+                if len(episode_rewards_development) > 1 else 0.0)
             episode += 1
+            if self.cfg[nms.ENV][nms.SAVE_AND_LOG] and self.train_cfg[nms.ENV][nms.WANDB_LOG]:
+                log_wandb_training(episode, sum(episode_rewards_development[-1]),
+                                   np.sum([reward * pow(self.gamma, i) for i, reward in enumerate(episode_rewards_development[-1])]),
+                                   return_change_development[-1])
             pbar.update(global_steps - pbar.n)

         pbar.close()
-        if self.cfg[nms.ENV][nms.SAVE_AND_LOG]:
-            plot_reward_development(self.reward_development, self.results_path)
-            create_info_maps(env, get_all_observations(env, self.cfg, self.n_agents),
+        if self.train_cfg[nms.ENV][nms.SAVE_AND_LOG]:
+            return_development = [np.sum(rewards) for rewards in episode_rewards_development]
+            discounted_return_development = [np.sum([reward * pow(self.gamma, i) for i, reward in enumerate(ep_rewards)]) for ep_rewards in episode_rewards_development]
+            plot_return_development(return_development, self.results_path)
+            plot_return_development(discounted_return_development, self.results_path, discounted=True)
+            plot_return_development_change(return_change_development, self.results_path)
+            create_info_maps(env, get_all_observations(env, self.train_cfg, self.n_agents),
                              get_dirt_piles_positions(env), self.results_path, self.agents, self.act_dim, self)
+            metrics_data = {"episode_rewards_development": episode_rewards_development,
+                            "return_development": return_development,
+                            "discounted_return_development": discounted_return_development,
+                            "return_change_development": return_change_development}
+            with open(f"{self.results_path}/metrics", "wb") as pickle_file:
+                pickle.dump(metrics_data, pickle_file)
             save_agent_models(self.results_path, self.agents)
             plot_action_maps(env, [self], self.results_path)
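
Aside on the new return bookkeeping in train_loop: the undiscounted return is the plain sum of an episode's step rewards, while the discounted return weights the reward at step i by gamma**i, matching the list comprehensions above. A minimal numeric sketch with made-up reward values:

import numpy as np

gamma = 0.99                   # training reads gamma from the config; 0.99 is the eval default used above
ep_rewards = [0.0, 0.0, 50.0]  # illustrative per-step rewards of one episode, not taken from the commit
undiscounted_return = np.sum(ep_rewards)                                           # 50.0
discounted_return = np.sum([r * pow(gamma, i) for i, r in enumerate(ep_rewards)])  # 50 * 0.99**2 = 49.005
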
@@ -137,23 +195,26 @@ class A2C:
     def eval_loop(self, n_episodes):
         """ Function for performing inference """
         env = self.eval_factory
-        self.set_cfg(eval=True)
         episode, results = 0, []
         dirt_piles_positions = get_dirt_piles_positions(env)
-        indices = distribute_indices(env, self.cfg, self.n_agents)
+        print("Dirt Piles positions", dirt_piles_positions)
+        indices = distribute_indices(env, self.eval_cfg, self.n_agents)
         target_pile = [partition[0] for partition in
                        indices]  # list of pointers that point to the current target pile for each agent
-        if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
+        if self.eval_cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
             cleaned_dirt_piles = [{dirt_piles_positions[idx]: False for idx in indices[i]} for i in
                                   range(self.n_agents)]
-        else: cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]
+        else:
+            cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]
+        cleaned_dirt_piles_per_step = []

         while episode < n_episodes:
             _ = env.reset()
             set_agents_spawnpoints(env, self.n_agents)
-            if self.cfg[nms.ENV][nms.EVAL_RENDER]:
+            if self.eval_cfg[nms.ENV][nms.EVAL_RENDER]:
                 # Don't render auxiliary piles
-                if self.cfg[nms.ALGORITHM][nms.AUXILIARY_PILES]:
+                if self.eval_cfg[nms.ALGORITHM][nms.AUXILIARY_PILES]:
                     auxiliary_piles = [pile for idx, pile in enumerate(env.state.entities[nms.DIRT_PILES]) if
                                        idx % 2 == 0]
                     for pile in auxiliary_piles:
@@ -162,19 +223,23 @@ class A2C:
                 env._renderer.fps = 5  # Slow down agent movement

             # Reset current target pile at episode begin if all piles have to be cleaned in one episode
-            if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED, nms.SHARED]:
+            if self.eval_cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED, nms.SHARED]:
                 target_pile = [partition[0] for partition in indices]
-                if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
+                if self.eval_cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
                     cleaned_dirt_piles = [{dirt_piles_positions[idx]: False for idx in indices[i]} for i in
                                           range(self.n_agents)]
-                else: cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]
+                else:
+                    cleaned_dirt_piles = [{pos: False for pos in dirt_piles_positions} for _ in range(self.n_agents)]

-            ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, self.cfg, self.n_agents)
+            ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, self.eval_cfg, self.n_agents)

             # Supply each agent with its local observation
-            obs = transform_observations(env, ordered_dirt_piles, target_pile, self.cfg, self.n_agents)
+            obs = transform_observations(env, ordered_dirt_piles, target_pile, self.eval_cfg, self.n_agents)
             done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
+            cleaned_dirt_piles_per_step.append([])
+            ep_steps = 0

             while not all(done):
                 action = self.use_door_or_move(env, obs, cleaned_dirt_piles, det=True) \
                     if nms.DOORS in env.state.entities.keys() else self.execute_policy(obs, env,
@@ -183,18 +248,37 @@ class A2C:
                 # Handle case where agent is on field with dirt
                 reward, done = self.handle_dirt(env, cleaned_dirt_piles, ordered_dirt_piles, target_pile, indices,
-                                                reward, done)
+                                                reward, done, self.eval_cfg)
+                ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, self.eval_cfg, self.n_agents)

                 # Get transformed next_obs that might have been updated because of handle_dirt
-                next_obs = transform_observations(env, ordered_dirt_piles, target_pile, self.cfg, self.n_agents)
+                next_obs = transform_observations(env, ordered_dirt_piles, target_pile, self.eval_cfg, self.n_agents)

                 done = [done] * self.n_agents if isinstance(done, bool) else done

-                if self.cfg[nms.ENV][nms.EVAL_RENDER]: env.render()
+                if self.eval_cfg[nms.ENV][nms.EVAL_RENDER]: env.render()
                 obs = next_obs

+                # Count the overall number of cleaned dirt piles in each step
+                cleaned_piles = 0
+                for dict in cleaned_dirt_piles:
+                    for value in dict.values():
+                        if value:
+                            cleaned_piles += 1
+                cleaned_dirt_piles_per_step[-1].append(cleaned_piles)
+                ep_steps += 1

             episode += 1
+            print(ep_steps)
+            print(cleaned_dirt_piles_per_step)
+
+        if self.eval_cfg[nms.ENV][nms.SAVE_AND_LOG]:
+            metrics_data = {"cleaned_dirt_piles_per_step": cleaned_dirt_piles_per_step}
+            with open(f"{self.results_path}/metrics", "wb") as pickle_file:
+                pickle.dump(metrics_data, pickle_file)

     ########## Helper functions ########
@@ -235,14 +319,18 @@ class A2C:
                                               a.name == nms.USE_DOOR))
                 # Don't include action in agent experience
                 else:
-                    if det: action.append(int(agent.pi(agent_obs, det=True)[0]))
-                    else: action.append(int(agent.step(agent_obs)))
+                    if det:
+                        action.append(int(agent.pi(agent_obs, det=True)[0]))
+                    else:
+                        action.append(int(agent.step(agent_obs)))
             else:
-                if det: action.append(int(agent.pi(agent_obs, det=True)[0]))
-                else: action.append(int(agent.step(agent_obs)))
+                if det:
+                    action.append(int(agent.pi(agent_obs, det=True)[0]))
+                else:
+                    action.append(int(agent.step(agent_obs)))
         return action

-    def handle_dirt(self, env, cleaned_dirt_piles, ordered_dirt_piles, target_pile, indices, reward, done):
+    def handle_dirt(self, env, cleaned_dirt_piles, ordered_dirt_piles, target_pile, indices, reward, done, cfg):
         """ Check if agent moved on field with dirt. If that is the case collect dirt automatically """
         agents_positions = get_agents_positions(env, self.n_agents)
         dirt_piles_positions = get_dirt_piles_positions(env)
@@ -257,10 +345,10 @@ class A2C:
                 reward[idx] += 50
                 cleaned_dirt_piles[idx][pos] = True

                 # Set pointer to next dirt pile
-                update_target_pile(env, idx, target_pile, indices, self.cfg)
+                update_target_pile(env, idx, target_pile, indices, cfg)
                 update_ordered_dirt_piles(idx, cleaned_dirt_piles, ordered_dirt_piles, env,
-                                          self.cfg, self.n_agents)
+                                          cfg, self.n_agents)
-                if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SINGLE:
+                if cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SINGLE:
                     done = True
                     if all(cleaned_dirt_piles[idx].values()):
                         # Reset cleaned_dirt_piles indicator
@@ -274,10 +362,10 @@ class A2C:
                 dirt_at_position = env.state[nms.DIRT_PILES].by_pos(pos)
                 dirt_at_position[0].set_new_amount(0)

-        if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED]:
+        if cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED]:
             if all([all(cleaned_dirt_piles[i].values()) for i in range(self.n_agents)]):
                 done = True
-        elif self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SHARED:
+        elif cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SHARED:
             # End episode if both agents together have cleaned all dirt piles
             if all(get_all_cleaned_dirt_piles(dirt_piles_positions, cleaned_dirt_piles, self.n_agents).values()):
                 done = True
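
With set_cfg removed, the class is now constructed once per mode instead of switching configs at call time. A minimal usage sketch; only mode=nms.TRAIN ('train') is fixed by the code above, while the config loader, the "eval" mode label, and the run identifier are assumptions:

# Hypothetical driver; load_config and the run name are placeholders, not part of the commit.
train_cfg = load_config("train_cfg.yaml")
eval_cfg = load_config("eval_cfg.yaml")

trainer = A2C(train_cfg=train_cfg, eval_cfg=eval_cfg, mode=nms.TRAIN)
trainer.train_loop()

evaluator = A2C(eval_cfg=eval_cfg, mode="eval")   # any mode other than nms.TRAIN selects the eval branch
evaluator.load_agents(["example_run_0"])          # run identifiers are placeholders
evaluator.eval_loop(n_episodes=1)
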

View File

@@ -35,3 +35,9 @@ class Names:
     SINGLE = 'single'
     DISTRIBUTED = 'distributed'
     SHARED = 'shared'
+    EARLY_STOPPING = 'early_stopping'
+    TRAIN = 'train'
+    SEED = 'seed'
+    LAST_N_EPISODES = 'last_n_episodes'
+    MEAN_TARGET_CHANGE = 'mean_target_change'
+    WANDB_LOG = 'wandb_log'
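
These constants are the keys the algorithm looks up under cfg[nms.ALGORITHM] and cfg[nms.ENV]. A sketch of the corresponding config entries in Python dict form; the 'algorithm'/'env' section names and all concrete values are assumptions, only the key names come from this commit:

assumed_train_cfg = {
    "algorithm": {
        "seed": -1,                  # -1 makes setup() draw a random seed from range(1000)
        "early_stopping": True,      # switches train_loop to the early-stopping loop condition
        "last_n_episodes": 100,      # window used by has_low_change_phase_started (value assumed)
        "mean_target_change": 2.0,   # threshold on the mean absolute return change (value assumed)
    },
    "env": {
        "wandb_log": True,           # together with save_and_log enables log_wandb_training
    },
}
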

View File

@@ -2,7 +2,9 @@ import copy
 from typing import List

 import numpy as np
+import pandas as pd
 import torch
+import wandb

 from marl_factory_grid.algorithms.rl.base_a2c import cumulate_discount
 from marl_factory_grid.algorithms.rl.constants import Names
@@ -331,3 +333,38 @@ def save_agent_models(results_path, agents):
     for idx, agent in enumerate(agents):
         agent.pi.save_model_parameters(results_path)
         agent.vf.save_model_parameters(results_path)
+
+
+def has_low_change_phase_started(return_change_development, last_n_episodes, mean_target_change):
+    """ Checks if training has reached a phase with only marginal average change """
+    if np.mean(np.abs(return_change_development[-last_n_episodes:])) < mean_target_change:
+        print("Low change phase started.")
+        return True
+    return False
+
+
+def significant_deviation(return_change_development, low_change_phase_start_episode):
+    """ Determines if a significant return deviation has occurred in the last episode """
+    return_change_development = return_change_development[low_change_phase_start_episode:]
+
+    df = pd.DataFrame({'Episode': range(len(return_change_development)), 'DeltaReturn': return_change_development})
+    df['Difference'] = df['DeltaReturn'].diff().abs()
+
+    # Only the most extreme changes (those that are greater than 99.99% of all changes) will be considered significant
+    threshold = df['Difference'].quantile(0.9999)
+
+    # Identify significant changes
+    significant_changes = df[df['Difference'] > threshold]
+    print("Threshold: ", threshold, "Significant changes: ", significant_changes)
+
+    if len(significant_changes["Episode"]) > 0:
+        return True
+    return False
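
A self-contained sketch of how these two helpers interact, run on synthetic data; the reward-change values and the config numbers are invented, only the function signatures come from the code above:

import numpy as np

# Synthetic per-episode return changes: noisy for 200 episodes, then nearly flat.
rng = np.random.default_rng(0)
return_changes = list(rng.normal(0, 20, size=200)) + list(rng.normal(0, 0.5, size=120))

last_n_episodes, mean_target_change = 100, 2.0   # assumed config values
if has_low_change_phase_started(return_changes, last_n_episodes, mean_target_change):
    # train_loop would remember this episode index and afterwards stop as soon as
    # significant_deviation reports an outlier or one of the episode/step budgets is exhausted
    deviated = significant_deviation(return_changes, low_change_phase_start_episode=220)
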
+
+
+def log_wandb_training(ep_count, ep_return, ep_return_discounted, ep_return_return_change):
+    """ Log training step metrics with weights&biases """
+    wandb.log({f"ep/step": ep_count,
+               f"ep/return": ep_return,
+               f"ep/discounted_return": ep_return_discounted,
+               f"ep/return_change": ep_return_return_change})