Reworked differentiation between train and eval execution + Renamed cfgs + Added algorithm seeding + Included early stopping functionality + Added weights&biases logging

This commit is contained in:
Julian Schönberger
2024-08-09 16:30:04 +02:00
parent 81b12612ed
commit 8e8e925278
3 changed files with 190 additions and 59 deletions

View File

@ -2,7 +2,9 @@ import copy
from typing import List
import numpy as np
import pandas as pd
import torch
import wandb
from marl_factory_grid.algorithms.rl.base_a2c import cumulate_discount
from marl_factory_grid.algorithms.rl.constants import Names
@ -331,3 +333,38 @@ def save_agent_models(results_path, agents):
for idx, agent in enumerate(agents):
agent.pi.save_model_parameters(results_path)
agent.vf.save_model_parameters(results_path)
def has_low_change_phase_started(return_change_development, last_n_episodes, mean_target_change):
""" Checks if training has reached a phase with only marginal average change """
if np.mean(np.abs(return_change_development[-last_n_episodes:])) < mean_target_change:
print("Low change phase started.")
return True
return False
def significant_deviation(return_change_development, low_change_phase_start_episode):
""" Determines if a significant return deviation has occurred in the last episode """
return_change_development = return_change_development[low_change_phase_start_episode:]
df = pd.DataFrame({'Episode': range(len(return_change_development)), 'DeltaReturn': return_change_development})
df['Difference'] = df['DeltaReturn'].diff().abs()
# Only the most extreme changes (those that are greater than 99.99% of all changes) will be considered significant
threshold = df['Difference'].quantile(0.9999)
# Identify significant changes
significant_changes = df[df['Difference'] > threshold]
print("Threshold: ", threshold, "Significant changes: ", significant_changes)
if len(significant_changes["Episode"]) > 0:
return True
return False
def log_wandb_training(ep_count, ep_return, ep_return_discounted, ep_return_return_change):
""" Log training step metrics with weights&biases """
wandb.log({f"ep/step": ep_count,
f"ep/return": ep_return,
f"ep/discounted_return": ep_return_discounted,
f"ep/return_change": ep_return_return_change})