mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-12 07:42:41 +02:00
Reworked differentiation between train and eval execution + Renamed cfgs + Added algorithm seeding + Included early stopping functionality + Added weights&biases logging
This commit is contained in:
@ -2,7 +2,9 @@ import copy
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from marl_factory_grid.algorithms.rl.base_a2c import cumulate_discount
|
||||
from marl_factory_grid.algorithms.rl.constants import Names
|
||||
@ -331,3 +333,38 @@ def save_agent_models(results_path, agents):
|
||||
for idx, agent in enumerate(agents):
|
||||
agent.pi.save_model_parameters(results_path)
|
||||
agent.vf.save_model_parameters(results_path)
|
||||
|
||||
|
||||
def has_low_change_phase_started(return_change_development, last_n_episodes, mean_target_change):
|
||||
""" Checks if training has reached a phase with only marginal average change """
|
||||
if np.mean(np.abs(return_change_development[-last_n_episodes:])) < mean_target_change:
|
||||
print("Low change phase started.")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def significant_deviation(return_change_development, low_change_phase_start_episode):
|
||||
""" Determines if a significant return deviation has occurred in the last episode """
|
||||
return_change_development = return_change_development[low_change_phase_start_episode:]
|
||||
|
||||
df = pd.DataFrame({'Episode': range(len(return_change_development)), 'DeltaReturn': return_change_development})
|
||||
df['Difference'] = df['DeltaReturn'].diff().abs()
|
||||
|
||||
# Only the most extreme changes (those that are greater than 99.99% of all changes) will be considered significant
|
||||
threshold = df['Difference'].quantile(0.9999)
|
||||
|
||||
# Identify significant changes
|
||||
significant_changes = df[df['Difference'] > threshold]
|
||||
print("Threshold: ", threshold, "Significant changes: ", significant_changes)
|
||||
|
||||
if len(significant_changes["Episode"]) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def log_wandb_training(ep_count, ep_return, ep_return_discounted, ep_return_return_change):
|
||||
""" Log training step metrics with weights&biases """
|
||||
wandb.log({f"ep/step": ep_count,
|
||||
f"ep/return": ep_return,
|
||||
f"ep/discounted_return": ep_return_discounted,
|
||||
f"ep/return_change": ep_return_return_change})
|
||||
|
Reference in New Issue
Block a user