intermediate backup

2025-05-03 20:46:14 +02:00
parent 2b0a5728d4
commit 6542caf48f
38 changed files with 4513 additions and 1067 deletions


@@ -1,30 +1,34 @@
import argparse
import logging
import sys
import os
import random
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Import necessary components from your project structure
# Assuming forecasting_model is an installable package or available on PYTHONPATH
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from typing import Dict, List, Any, Optional
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
# Import the new classic training function
from forecasting_model.train.classic import run_classic_training
from typing import Dict, List, Optional, Tuple, Union
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
@@ -33,396 +37,552 @@ pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
# Configure logging early. Level might be adjusted by config.
# Configure logging early. Level might be adjusted by config later.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments."""
parser = argparse.ArgumentParser(
description="Run the Time Series Forecasting training pipeline.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--seed',
type=int,
default=None, # Default to None, use config value if not provided
help="Override random seed defined in config."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)
parser.add_argument(
'--output-dir',
type=str,
default='output/cv_results', # Default output base directory
help="Base directory for saving cross-validation results (checkpoints, logs, plots)."
)
args = parser.parse_args()
return args
# --- Helper Functions ---
def load_config(config_path: Path) -> MainConfig:
# --- Single Fold Processing Function ---
def run_single_fold(
fold_num: int,
train_idx: np.ndarray,
val_idx: np.ndarray,
test_idx: np.ndarray,
config: MainConfig,
full_df: pd.DataFrame,
output_base_dir: Path # Receives Path object from run_training_pipeline
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
"""
Load and validate configuration from YAML file using Pydantic.
Runs the pipeline for a single cross-validation fold.
Args:
config_path: Path to the YAML configuration file.
fold_num: The zero-based index of the current fold.
train_idx: Indices for the training set.
val_idx: Indices for the validation set.
test_idx: Indices for the test set.
config: The main configuration object.
full_df: The complete raw DataFrame.
output_base_dir: The base directory Path for saving results.
Returns:
Validated MainConfig object.
Raises:
FileNotFoundError: If the config file doesn't exist.
yaml.YAMLError: If the file is not valid YAML.
pydantic.ValidationError: If the config doesn't match the schema.
A tuple containing:
- fold_metrics: Dictionary of test metrics for the fold (e.g., {'MAE': ..., 'RMSE': ...}).
- best_val_score: The best validation score achieved during training (or None).
- saved_model_path: Path to the best saved model checkpoint (or None).
- saved_target_scaler_path: Path to the saved target scaler (or None).
- saved_input_size_path: Path to the saved input size file (or None).
- saved_config_path: Path to the saved config file for this fold (or None).
"""
if not config_path.is_file():
logger.error(f"Configuration file not found at: {config_path}")
raise FileNotFoundError(f"Config file not found: {config_path}")
fold_start_time = time.perf_counter()
fold_id = fold_num + 1 # User-facing fold number (1-based)
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")
fold_metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Default in case of failure
best_val_score: Optional[float] = None
best_model_path_str: Optional[str] = None # Use a different name for the string from callback
# Variables to hold prediction results for plotting later
all_preds_scaled: Optional[np.ndarray] = None
all_targets_scaled: Optional[np.ndarray] = None
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Need to keep scaler reference
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
# Variables to store paths of saved artifacts
saved_model_path: Optional[Path] = None
saved_target_scaler_path: Optional[Path] = None
saved_input_size_path: Optional[Path] = None
saved_config_path: Optional[Path] = None
logger.info(f"Loading configuration from: {config_path}")
try:
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
# Keep scaler and input_size references returned by prepare_fold_data_and_loaders
train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders( # Renamed target_scaler
full_df=full_df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
)
target_scaler = target_scaler_fold # Store the scaler in the outer scope
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# Validate configuration using Pydantic model
config = MainConfig(**config_dict)
logger.info("Configuration loaded and validated successfully.")
return config
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
raise
except Exception as e: # Catches Pydantic validation errors too
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
raise
# Save necessary items for potential later use (e.g., ensemble)
# Capture the paths when saving
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
torch.save(target_scaler, saved_target_scaler_path)
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
def set_seeds(seed: Optional[int] = 42) -> None:
"""
Set random seeds for reproducibility across libraries.
# Save input size and capture path
saved_input_size_path = fold_output_dir / "input_size.pt"
torch.save(input_size, saved_input_size_path)
Args:
seed: The seed value to use. If None, uses default 42.
"""
if seed is None:
seed = 42
logger.warning(f"No seed provided, using default seed: {seed}")
else:
logger.info(f"Setting random seed: {seed}")
# Save config for this fold (needed for reloading model) and capture path
config_dump = config.model_dump()
saved_config_path = fold_output_dir / "config.yaml" # Capture the path before saving
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ensure reproducibility for CUDA operations where possible
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU
# These settings can slow down training but improve reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# PyTorch Lightning seeding (optional, as we seed torch directly)
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
with open(saved_config_path, 'w') as f:
yaml.dump(config_dump, f, default_flow_style=False)
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""
Calculate mean and standard deviation of metrics across folds.
Handles potential NaN values by ignoring them.
Args:
all_fold_metrics: A list where each element is a dictionary of
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
# --- Model Initialization ---
model = LSTMForecastLightningModule(
model_config=config.model,
train_config=config.training,
input_size=input_size,
target_scaler=target_scaler_fold # Pass scaler during init
)
logger.info("LSTMForecastLightningModule initialized.")
Returns:
A dictionary where keys are metric names and values are dicts
containing 'mean' and 'std' for that metric across folds.
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
"""
if not all_fold_metrics:
logger.warning("Received empty list for metric aggregation.")
return {}
# --- PyTorch Lightning Callbacks ---
# Ensure monitor_metric matches the exact name logged in model.py
monitor_metric = "val_MeanAbsoluteError_Original_Scale" # Corrected metric name
monitor_mode = "min"
aggregated: Dict[str, Dict[str, float]] = {}
# Get metric names from the first valid fold's results
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
if not first_valid_metrics:
logger.warning("No valid fold metrics found for aggregation.")
return {}
metric_names = list(first_valid_metrics.keys())
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001,
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
for metric in metric_names:
# Collect values for this metric across all folds, ignoring NaNs
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
valid_values = [v for v in values if v is not None and not np.isnan(v)]
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}",
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
callbacks.append(early_stop_callback)
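        # For reference, a minimal sketch (assumed, not the actual model code) of how the
        # monitored key has to be logged inside LSTMForecastLightningModule so that
        # EarlyStopping and ModelCheckpoint can find it:
        #
        #     def validation_step(self, batch, batch_idx):
        #         ...
        #         self.log("val_MeanAbsoluteError_Original_Scale", mae_orig_scale,
        #                  on_step=False, on_epoch=True, prog_bar=True)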
# --- PyTorch Lightning Logger ---
# Log to a subdir specific to the fold, relative to output_base_dir
log_dir = output_base_dir / f"fold_{fold_id:02d}" / "training_logs"
pl_logger = CSVLogger(save_dir=str(log_dir.parent), name=log_dir.name, version='') # Use name for subdir, version='' to avoid 'version_0'
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None
precision = getattr(config.training, 'precision', 32)
trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10),
enable_progress_bar=True,
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")
# Store best validation score and path
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
# Capture the best model path reported by the checkpoint callback
best_model_path_str = trainer.checkpoint_callback.best_model_path # Capture the string path
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {best_val_score:.4f}")
# Check if best_model_path was actually set by the callback
if best_model_path_str:
saved_model_path = Path(best_model_path_str) # Convert string to Path object and store
logger.info(f"Best model checkpoint path: {best_model_path_str}")
else:
logger.warning(f"ModelCheckpoint callback did not report a best_model_path for Fold {fold_id}.")
if not valid_values:
logger.warning(f"No valid values found for metric '{metric}' across folds.")
mean_val = np.nan
std_val = np.nan
else:
mean_val = float(np.mean(valid_values))
std_val = float(np.std(valid_values))
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path_str = None # Ensure string path is None if no best score
aggregated[metric] = {'mean': mean_val, 'std': std_val}
return aggregated
# --- Prediction on Test Set ---
logger.info(f"Starting prediction for Fold {fold_id} using {'best checkpoint' if saved_model_path else 'last model'}...")
# Use the best checkpoint path if available, otherwise use the in-memory model instance
ckpt_path_for_predict = str(saved_model_path) if saved_model_path else None # Use the saved Path object, convert to string for ckpt_path
prediction_results_list = trainer.predict(
model=model, # Use the in-memory model instance
dataloaders=test_loader,
ckpt_path=ckpt_path_for_predict # Specify checkpoint path if needed, though using model=model is typical
)
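        # Assumed contract for LSTMForecastLightningModule.predict_step (illustrative
        # sketch, not the actual implementation): each batch returns a dict of scaled
        # tensors so the concatenation below works, e.g.
        #
        #     def predict_step(self, batch, batch_idx, dataloader_idx=0):
        #         x, y = batch
        #         return {"preds_scaled": self(x).detach().cpu(),
        #                 "targets_scaled": y.detach().cpu()}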
# --- Process Prediction Results & Get Time Index ---
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
all_preds_scaled = None # Ensure these are None on failure
all_targets_scaled = None
else:
try:
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
n_predictions = len(all_preds_scaled)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
if len(all_targets_scaled) != n_predictions:
logger.error(f"Fold {fold_id}: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
raise ValueError("Prediction and target count mismatch during evaluation.")
else:
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate or plot original scale targets.")
all_targets_scaled = None
logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
prediction_target_time_index_h1 = None
if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
try:
test_block_index = full_df.index[test_idx]
seq_len = config.features.sequence_length
first_horizon = config.features.forecast_horizon[0]
target_indices_h1 = test_idx + seq_len + first_horizon - 1
valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]
if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
if len(prediction_target_time_index_h1) != n_predictions:
logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
prediction_target_time_index_h1 = None
else:
logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
prediction_target_time_index_h1 = None
# Save the calculated index if it's valid and evaluation plots are enabled
if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
try:
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
except Exception as save_e:
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
elif prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
except Exception as e:
logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
prediction_target_time_index_h1 = None
else:
logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
if prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
# --- End Index Calculation and Saving ---
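                # Worked example of the index arithmetic above (assumed values, for
                # illustration only): with sequence_length=24 and forecast_horizon=[1],
                # an input window starting at raw row i covers rows [i, i+23], so its
                # first-horizon target sits at raw row i + 24 + 1 - 1 = i + 24:
                #
                #     >>> import numpy as np
                #     >>> test_idx = np.array([100, 101, 102])
                #     >>> test_idx + 24 + 1 - 1
                #     array([124, 125, 126])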
# --- Evaluation ---
if all_targets_scaled is not None: # Only evaluate if targets are available
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled, # Pass the (N, H) array
y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
target_scaler=target_scaler,
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=str(fold_output_dir),
plot_subdir="plots",
# Pass the calculated index for the targets being plotted (h1 reference)
prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
plot_title_prefix=f"CV Fold {fold_id}"
)
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
else:
logger.error(f"Skipping evaluation for Fold {fold_id} due to missing targets.")
# --- Multi-Horizon Plotting ---
if config.evaluation.save_plots and all_preds_scaled is not None and all_targets_scaled is not None and prediction_target_time_index_h1 is not None and target_scaler is not None:
logger.info(f"Generating multi-horizon plot for Fold {fold_id}...")
try:
multi_horizon_plot_path = fold_output_dir / "plots" / "multi_horizon_forecast.png"
                # save_plot is imported from forecasting_model.io.plotting at the top of the module.
fig = create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons=all_targets_scaled,
y_pred_scaled_all_horizons=all_preds_scaled,
target_scaler=target_scaler,
prediction_time_index_h1=prediction_target_time_index_h1,
forecast_horizons=config.features.forecast_horizon,
title=f"Fold {fold_id} Multi-Horizon Forecast",
max_points=1000 # Limit points for clarity
)
# Check if save_plot is available or use fig.savefig()
try:
save_plot(fig, multi_horizon_plot_path)
except NameError:
# Fallback if save_plot is not defined/imported
fig.savefig(multi_horizon_plot_path)
plt.close(fig) # Close the figure after saving
logger.warning("Using fig.savefig as save_plot function was not found.")
except Exception as plot_e:
logger.error(f"Fold {fold_id}: Failed to generate multi-horizon plot: {plot_e}", exc_info=True)
elif config.evaluation.save_plots:
logger.warning(f"Fold {fold_id}: Skipping multi-horizon plot due to missing data (preds, targets, time index, or scaler).")
except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
except ValueError as e: # Catch specific error from above
logger.error(f"ValueError processing prediction results for Fold {fold_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
# --- Plot Loss Curve for Fold ---
try:
actual_log_dir = Path(pl_logger.log_dir) / pl_logger.name # Should be .../fold_XX/training_logs
metrics_file_path = actual_log_dir / "metrics.csv"
if metrics_file_path.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file_path,
output_path=fold_output_dir / "plots" / "loss_curve.png", # Save in plots subdir
title=f"Fold {fold_id} Training Progression",
train_loss_col='train_loss',
val_loss_col='val_loss' # This function handles fallback
)
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
else:
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
# --- End Loss Curve Plotting ---
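        # For reference: Lightning's CSVLogger writes a single metrics.csv with one row
        # per logged step/epoch. The column names below are assumed to match what the
        # module logs; plot_loss_curve_from_csv picks the train/val loss columns from it:
        #
        #     epoch,step,train_loss,val_loss,val_MeanAbsoluteError_Original_Scale
        #     0,57,0.0213,,
        #     0,57,,0.0198,0.3411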
def save_results(results: Dict, filename: Path):
"""Save dictionary results to a JSON file."""
try:
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
json.dump(results, f, indent=4)
logger.info(f"Saved results to {filename}")
except Exception as e:
logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
        # Paths initialised to None above remain None if the error occurred before they were set.
    finally:
        # Clean up GPU memory explicitly. Guard the deletes: an early failure may mean
        # these names were never bound, and a NameError here would mask the real error.
        try:
            del model, trainer  # Ensure objects are deleted before clearing cache
            del train_loader, val_loader, test_loader  # Loaders might hold references
        except NameError:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("Cleared CUDA cache.")
fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
# Return the calculated fold metrics, best validation score, and saved artifact paths
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path
# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
"""Runs the full cross-validation training and evaluation pipeline."""
"""Runs the full training and evaluation pipeline based on config flags."""
start_time = time.perf_counter()
logger.info(f"Starting training pipeline. Results will be saved to: {output_base_dir}")
output_base_dir.mkdir(parents=True, exist_ok=True)
# --- Data Loading ---
try:
df = load_raw_data(config.data)
except Exception as e:
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
sys.exit(1) # Cannot proceed without data
# --- Cross-Validation Setup ---
try:
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
# --- Initialize results ---
all_fold_test_metrics: List[Dict[str, float]] = []
all_fold_best_val_scores: Dict[int, Optional[float]] = {} # Store best val score per fold
all_fold_best_val_scores: Dict[int, Optional[float]] = {}
aggregated_metrics: Dict = {}
final_results: Dict = {} # Initialize empty results dict
# --- Cross-Validation Loop ---
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_start_time = time.perf_counter()
fold_id = fold_num + 1
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")
if config.run_cross_validation:
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
try:
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=df,
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
# Unpack the two new return values from run_single_fold
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
config=config,
full_df=df,
output_base_dir=output_base_dir
)
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# --- Model Initialization ---
# Pass input_size directly, ModelConfig no longer holds it.
# Ensure forecast horizon is consistent (checked in MainConfig validation)
current_model_config = config.model # Use the validated model config
model = LSTMForecastLightningModule(
model_config=current_model_config, # Does not contain input_size
train_config=config.training,
input_size=input_size, # Pass the dynamically determined input_size
target_scaler=target_scaler # Pass the fold-specific scaler
)
logger.info("LSTMForecastLightningModule initialized.")
# --- PyTorch Lightning Callbacks ---
# Monitor the validation MAE on the original scale (logged by LightningModule)
monitor_metric = "val_mae_orig_scale"
monitor_mode = "min"
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001, # Minimum change to qualify as improvement
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
# Checkpoint callback to save the best model based on validation metric
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
# Learning rate monitor callback
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
callbacks.append(early_stop_callback)
# --- PyTorch Lightning Logger ---
# Log metrics to a CSV file within the fold directory
pl_logger = CSVLogger(save_dir=str(output_base_dir), name=f"fold_{fold_id:02d}", version='logs')
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
# Determine accelerator and devices based on PyTorch check
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs [0], [1] etc.
precision = getattr(config.training, 'precision', 32) # Default to 32-bit
trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
enable_progress_bar=True, # Set to False for less verbose runs (e.g., HPO)
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
# deterministic=True, # For stricter reproducibility (can slow down)
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")
# Store best validation score for this fold
best_val_score = trainer.checkpoint_callback.best_model_score
best_model_path = trainer.checkpoint_callback.best_model_path
all_fold_best_val_scores[fold_id] = best_val_score.item() if best_val_score else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
logger.info(f"Best model checkpoint path: {best_model_path}")
else:
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path = None # Ensure evaluation doesn't try to load 'best' if checkpointing failed
# --- Prediction on Test Set ---
# Use trainer.predict() to get model outputs
logger.info(f"Starting prediction for Fold {fold_id} using best checkpoint...")
# predict_step returns dict {'preds_scaled': ..., 'targets_scaled': ...}
# We pass the test_loader here, which yields (x, y) pairs, so predict_step will include targets
prediction_results_list = trainer.predict(
# model=model, # Not needed if using ckpt_path
ckpt_path=best_model_path if best_model_path else 'last', # Load best model or last if best failed
dataloaders=test_loader
# return_predictions=True # Default is True
)
# Check if prediction returned results
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
else:
try:
# Concatenate predictions and targets from predict_step results
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
# Check if targets were included (they should be if using test_loader)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
else:
# This case shouldn't happen if using test_loader, but good safeguard
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
raise ValueError("Targets missing from prediction results.")
# --- Final Evaluation & Plotting ---
logger.info(f"Processing prediction results for Fold {fold_id}...")
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled,
y_pred_scaled=all_preds_scaled,
target_scaler=target_scaler, # Use the scaler from this fold
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=output_base_dir, # Base dir for saving plots etc.
# time_index=df.iloc[test_idx].index # Pass time index if needed
)
# Save fold metrics
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
all_fold_test_metrics.append(fold_metrics)
all_fold_best_val_scores[fold_num + 1] = best_val_score
# --- (Optional) Log final test metrics using trainer.test() ---
# If you want the metrics logged by test_step aggregated, call test now.
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
# try:
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
# except Exception as e:
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
# --- Aggregation and Reporting for CV ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
final_results['aggregated_test_metrics'] = aggregated_metrics
final_results['per_fold_test_metrics'] = all_fold_test_metrics
final_results['per_fold_best_val_scores'] = all_fold_best_val_scores
# Save intermediate results after CV
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
except Exception as e:
# Catch errors during the fold processing (data prep, training, prediction, eval)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
all_fold_test_metrics.append({'MAE': np.nan, 'RMSE': np.nan})
# --- Cleanup per fold ---
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared CUDA cache.")
fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
# --- Aggregation and Final Reporting ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
# Save aggregated results
final_results = {
'aggregated_test_metrics': aggregated_metrics,
'per_fold_test_metrics': all_fold_test_metrics,
'per_fold_best_val_scores': all_fold_best_val_scores,
}
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
# Log final results
logger.info("--- Aggregated Cross-Validation Test Results ---")
if aggregated_metrics:
for metric, stats in aggregated_metrics.items():
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
else:
logger.warning("No metrics available for aggregation.")
logger.info("Skipping Cross-Validation loop as per config.")
# --- Ensemble Evaluation ---
if config.run_ensemble_evaluation:
# The validator in MainConfig already ensures run_cross_validation is also true here
logger.info("Starting ensemble evaluation...")
try:
ensemble_results = run_ensemble_evaluation(
config=config, # Pass config for context if needed by sub-functions
output_base_dir=output_base_dir
)
if ensemble_results:
logger.info("Ensemble evaluation completed successfully")
final_results['ensemble_results'] = ensemble_results
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
else:
logger.warning("No ensemble results were generated (potentially < 2 folds available).")
except Exception as e:
logger.error(f"Error during ensemble evaluation: {e}", exc_info=True)
else:
logger.info("Skipping Ensemble evaluation as per config.")
# --- Classic Training Run ---
if config.run_classic_training:
logger.info("Starting classic training run...")
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
try:
# Call the original classic training function directly
classic_metrics = run_classic_training(
config=config,
full_df=df,
output_base_dir=output_base_dir # It creates classic_run subdir internally
)
if classic_metrics:
logger.info(f"Classic training run completed. Test Metrics: {classic_metrics}")
final_results['classic_training_results'] = classic_metrics
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
# --- Plot Loss Curve for Classic Run ---
try:
classic_log_dir = classic_output_dir / "training_logs"
metrics_file = classic_log_dir / "metrics.csv"
version_dirs = list(classic_log_dir.glob("version_*"))
if version_dirs:
metrics_file = version_dirs[0] / "metrics.csv"
if metrics_file.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file,
output_path=classic_output_dir / "loss_curve.png",
title="Classic Run Training Progression",
train_loss_col='train_loss', # Changed from 'train_loss_epoch'
val_loss_col='val_loss' # Check your logged metric names
)
else:
logger.warning(f"Classic Run: Could not find metrics.csv at {metrics_file} for loss curve plot.")
except Exception as plot_e:
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
# --- End Classic Loss Plotting ---
else:
logger.warning("Classic training run did not produce metrics.")
except Exception as e:
logger.error(f"Error during classic training run: {e}", exc_info=True)
else:
logger.info("Skipping Classic training run as per config.")
# --- Final Logging Summary ---
logger.info("--- Final Summary ---")
# Log aggregated CV results if they exist
if 'aggregated_test_metrics' in final_results and final_results['aggregated_test_metrics']:
logger.info("--- Aggregated Cross-Validation Test Results ---")
for metric, stats in final_results['aggregated_test_metrics'].items():
logger.info(f"{metric}: {stats.get('mean', np.nan):.4f} ± {stats.get('std', np.nan):.4f}")
elif config.run_cross_validation:
logger.warning("Cross-validation was run, but no metrics were aggregated.")
# Log aggregated ensemble results if they exist
if 'ensemble_results' in final_results and final_results['ensemble_results']:
logger.info("--- Aggregated Ensemble Test Results (Mean over Test Folds) ---")
agg_ensemble = {}
for fold_res in final_results['ensemble_results'].values():
if isinstance(fold_res, dict):
for method, metrics in fold_res.items():
if method not in agg_ensemble: agg_ensemble[method] = {}
if isinstance(metrics, dict):
for m_name, m_val in metrics.items():
if m_name not in agg_ensemble[method]: agg_ensemble[method][m_name] = []
agg_ensemble[method][m_name].append(m_val)
else: logger.warning(f"Skipping non-dict metrics for ensemble method '{method}'.")
else: logger.warning("Skipping non-dict fold result in ensemble aggregation.")
for method, metrics_data in agg_ensemble.items():
logger.info(f" Ensemble Method: {method}")
for m_name, values in metrics_data.items():
valid_vals = [v for v in values if v is not None and not np.isnan(v)]
if valid_vals: logger.info(f" {m_name}: {np.mean(valid_vals):.4f} ± {np.std(valid_vals):.4f}")
else: logger.info(f" {m_name}: N/A")
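    # For context, the nested structure assumed above for final_results['ensemble_results']
    # (method names and values are illustrative only):
    #
    #     {"fold_01": {"mean":   {"MAE": 0.42, "RMSE": 0.63},
    #                  "median": {"MAE": 0.41, "RMSE": 0.61}},
    #      "fold_02": {...}}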
# Log classic results if they exist
if 'classic_training_results' in final_results and final_results['classic_training_results']:
logger.info("--- Classic Training Test Results ---")
classic_res = final_results['classic_training_results']
for metric, value in classic_res.items():
logger.info(f"{metric}: {value:.4f}")
logger.info("-------------------------------------------------")
end_time = time.perf_counter()
@@ -434,12 +594,6 @@ def run():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
output_dir = Path(args.output_dir)
# Adjust log level if debug flag is set
if args.debug:
logger.setLevel(logging.DEBUG)
logger.debug("# --- Debug mode enabled. --- #")
# --- Configuration Loading ---
try:
@@ -448,10 +602,20 @@ def run():
# Error already logged in load_config
sys.exit(1)
# --- Seed Setting ---
# Use command-line seed if provided, otherwise use config seed
seed = args.seed if args.seed is not None else getattr(config, 'random_seed', 42)
set_seeds(seed)
# --- Setup based on Config ---
# 1. Set Log Level
log_level_name = config.log_level.upper()
log_level = getattr(logging, log_level_name, logging.INFO)
logger.setLevel(log_level)
logger.info(f"Log level set to: {log_level_name}")
if log_level == logging.DEBUG:
logger.debug("# --- Debug mode enabled via config. --- #")
# 2. Set Seed
set_seeds(config.random_seed)
# 3. Determine Output Directory
output_dir = Path(config.output_dir)
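    # Illustrative excerpt of the YAML config consumed here (field names mirror the
    # attributes accessed above; values are placeholders, not project defaults):
    #
    #     log_level: INFO
    #     random_seed: 42
    #     output_dir: output/cv_results
    #     run_cross_validation: true
    #     run_ensemble_evaluation: true   # MainConfig validator requires CV to be enabled too
    #     run_classic_training: false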
# --- Pipeline Execution ---
try:
@@ -459,7 +623,7 @@ def run():
except SystemExit as e:
logger.warning(f"Pipeline exited with code {e.code}.")
sys.exit(e.code) # Propagate exit code
sys.exit(e.code)
except Exception as e:
logger.critical(f"An critical error occurred during pipeline execution: {e}", exc_info=True)
sys.exit(1)