intermediate backup

2025-05-03 20:46:14 +02:00
parent 2b0a5728d4
commit 6542caf48f
38 changed files with 4513 additions and 1067 deletions

View File

@ -4,7 +4,7 @@ import pandas as pd
import json
from typing import Optional, Dict, List, Any
# Use utils for config if that's the structure
from data_analysis.utils.config_model import settings
from data_analysis.utils.data_config_model import settings
import datetime
logger = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ from pathlib import Path
import pandas as pd
from typing import Tuple, Optional, Dict, Any
from data_analysis.utils.config_model import settings
from data_analysis.utils.data_config_model import settings
logger = logging.getLogger(__name__)

View File

@ -9,7 +9,7 @@ import shutil
import pandas as pd
from data_analysis.utils.config_model import settings # Assuming settings are configured
from data_analysis.utils.data_config_model import settings # Assuming settings are configured
from data_analysis.utils.report_model import ReportData
logger = logging.getLogger(__name__)

View File

@ -5,7 +5,7 @@ from pathlib import Path
import time
# Import necessary components from your project structure
from data_analysis.utils.config_model import load_settings, Settings # Import loading function and model
from data_analysis.utils.data_config_model import load_settings, Settings # Import loading function and model
from data_analysis.analysis.pipeline import run_eda_pipeline # Import the pipeline entry point
# Silence overly verbose libraries if needed (e.g., matplotlib)

View File

@ -2,6 +2,13 @@
project_name: "TimeSeriesForecasting" # Name for the project/run
random_seed: 42 # Optional: Global random seed for reproducibility
log_level: INFO # Or DEBUG
# --- Execution Control ---
run_cross_validation: true # Run the main cross-validation loop?
run_classic_training: true # Run a single classic train/val/test split?
run_ensemble_evaluation: true # Run ensemble evaluation (requires run_cross_validation=true)?
# --- End Execution Control ---
# --- Data Loading Configuration ---
data:
@ -20,7 +27,7 @@ data:
# --- Feature Engineering & Preprocessing Configuration ---
features:
sequence_length: 72 # REQUIRED: Lookback window size (e.g., 72 hours = 3 days)
forecast_horizon: 24 # REQUIRED: Number of steps ahead to predict (e.g., 24 hours)
forecast_horizon: [1, 6, 12, 24] # REQUIRED: List of steps ahead to predict (e.g., 1, 6, 12, and 24 hours ahead)
lags: [24, 48, 72, 168] # List of lag features to create (e.g., 1 day, 2 days, 3 days, 1 week)
rolling_window_sizes: [24, 72, 168] # List of window sizes for rolling stats (mean, std)
use_time_features: true # Create calendar features (hour, dayofweek, month, etc.)?
@ -62,7 +69,7 @@ training:
scheduler_step_size: null # Optional: Step size for StepLR scheduler (epochs). Set null/None to disable. Must be > 0 if set.
scheduler_gamma: null # Optional: Gamma factor for StepLR scheduler. Set null/None to disable. Must be 0 < gamma < 1 if set.
gradient_clip_val: 1.0 # Optional: Value for gradient clipping. Set null/None to disable. Must be >= 0.0 if set.
num_workers: 0 # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
num_workers: 4 # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
precision: 32 # Training precision (16, 32, 64, 'bf16')
# --- Cross-Validation Configuration (Rolling Window) ---
@ -80,9 +87,11 @@ evaluation:
# --- Optuna Hyperparameter Optimization Configuration ---
optuna:
enabled: false # Enable Optuna HPO? If true, requires optuna.py script.
n_trials: 20 # Number of trials to run (must be > 0)
storage: null # Optional: Optuna storage URL (e.g., "sqlite:///output/hpo_results/study.db"). If null, uses in-memory.
direction: "minimize" # Optimization direction ('minimize' or 'maximize')
metric_to_optimize: "val_mae_orig_scale" # Metric logged by LightningModule to optimize
pruning: true # Enable Optuna trial pruning?
enabled: true # Set to true to actually run HPO via optuna_run.py
study_name: "lstm_price_forecast_hpo_v1" # Specific name for this study
n_trials: 200 # Number of trials to run
storage: "sqlite:///output/hpo_results/study_v1.db" # Path to database file
direction: "minimize" # 'minimize' or 'maximize'
metric_to_optimize: "val_MeanAbsoluteError" # Metric logged in validation_step
pruning: true # Enable pruning
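For orientation, here is a minimal sketch of how the list-valued forecast_horizon and the Optuna block above might be modelled as a validated config schema. This assumes pydantic v2; the class names are illustrative only, not the project's actual forecast_config_model definitions.
from typing import List, Optional
from pydantic import BaseModel, field_validator

class FeatureConfigSketch(BaseModel):
    # Mirrors the YAML keys above; names are assumptions for illustration only.
    sequence_length: int
    forecast_horizon: List[int]            # e.g. [1, 6, 12, 24]
    lags: List[int] = []
    rolling_window_sizes: List[int] = []
    use_time_features: bool = True

    @field_validator("forecast_horizon")
    @classmethod
    def _horizons_must_be_positive(cls, v: List[int]) -> List[int]:
        if not v or any(h <= 0 for h in v):
            raise ValueError("forecast_horizon must be a non-empty list of positive integers")
        return sorted(v)

class OptunaConfigSketch(BaseModel):
    enabled: bool = False
    study_name: str = "study"
    n_trials: int = 20
    storage: Optional[str] = None          # e.g. "sqlite:///output/hpo_results/study_v1.db"
    direction: str = "minimize"
    metric_to_optimize: str = "val_MeanAbsoluteError"
    pruning: bool = True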

View File

@ -15,7 +15,7 @@ from .data_processing import (
prepare_fold_data_and_loaders,
TimeSeriesDataset
)
from .model import LSTMForecastLightningModule
from forecasting_model.train.model import LSTMForecastLightningModule
from .evaluation import (
evaluate_fold_predictions,
# Optionally expose the standalone evaluation utility if needed externally

View File

@ -5,9 +5,10 @@ import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import Tuple, Generator, List, Optional, Union, Dict, Literal, Type
import math # Add math import
# Use relative import for utils within the package
from .utils.config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
from .utils.forecast_config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
# Optional: Import wavelet library if needed later
# import pywt
@ -264,31 +265,39 @@ def engineer_features(df: pd.DataFrame, target_col: str, feature_config: Feature
if isinstance(nan_handler, str):
if nan_handler in ['ffill', 'bfill']:
fill_method = nan_handler
logger.debug(f"Filling NaNs in generated features using method: '{fill_method}'")
logger.debug(f"Selected NaN fill method for generated features: '{fill_method}'")
elif nan_handler == 'mean':
logger.warning("NaN filling with 'mean' in generated features is applied globally here;"
" consider per-fold mean filling if lookahead is a concern.")
# Calculate mean only on the slice provided, potentially leaking info if slice includes val/test
# Better to use ffill/bfill here or handle after split
fill_value = features_df[feature_cols_generated].mean() # Calculate mean per feature column
logger.debug("Filling NaNs in generated features using column means.")
fill_value = features_df[feature_cols_generated].mean()
logger.debug("Selected NaN fill method: column means.")
else:
logger.warning(f"Unsupported string fill_nan method '{nan_handler}' for generated features. Using 'ffill'.")
fill_method = 'ffill'
fill_method = 'ffill' # Default to ffill if unsupported string
elif isinstance(nan_handler, (int, float)):
fill_value = float(nan_handler)
logger.debug(f"Filling NaNs in generated features with value: {fill_value}")
logger.debug(f"Selected NaN fill value for generated features: {fill_value}")
else:
logger.warning(f"Invalid fill_nan type: {type(nan_handler)}. NaNs in features may remain.")
# Apply filling only to generated feature columns
if fill_method:
features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method=fill_method)
# Apply filling only to generated feature columns using recommended methods
if fill_method == 'ffill':
logger.debug("Applying .ffill() to generated features...")
features_df[feature_cols_generated] = features_df[feature_cols_generated].ffill()
# Apply bfill afterwards to handle any NaNs remaining at the very beginning
logger.debug("Applying .bfill() to handle any remaining NaNs at the start...")
features_df[feature_cols_generated] = features_df[feature_cols_generated].bfill()
elif fill_method == 'bfill':
logger.debug("Applying .bfill() to generated features...")
features_df[feature_cols_generated] = features_df[feature_cols_generated].bfill()
# Optionally apply ffill after bfill if you need to fill trailing NaNs (less common)
# features_df[feature_cols_generated] = features_df[feature_cols_generated].ffill()
elif fill_value is not None:
# fillna with Series/dict for column-wise mean, or scalar for constant value
logger.debug(f"Applying .fillna(value={fill_value}) to generated features...")
features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(value=fill_value)
# No else needed: if both fill_method and fill_value are None, no filling happens
else:
logger.warning("`fill_nan` is None. NaNs generated by feature engineering may remain.")
@ -366,36 +375,31 @@ class TimeSeriesCrossValidationSplitter:
# Estimate if None
elif self.initial_train_size is None:
min_samples_per_split_step = 2 # Heuristic minimum samples for val+test in one step
# Estimate val/test based on *potential* train size (crude)
# Assume train is roughly (1 - val - test) fraction for estimation
estimated_train_frac = max(0.1, 1.0 - self.val_frac - self.test_frac) # Ensure non-zero
estimated_train_n = int(self.n_samples * estimated_train_frac)
val_test_size_per_step = max(min_samples_per_split_step, int(estimated_train_n * (self.val_frac + self.test_frac)))
logger.info("Estimating fixed train size based on n_splits, val_frac, test_frac.")
# Estimate based on the total space needed for all splits:
# n_samples >= fixed_train_n + val_size + test_size + (n_splits - 1) * step_size
# n_samples >= fixed_train_n + int(fixed_train_n*val_frac) + n_splits * int(fixed_train_n*test_frac)
# n_samples >= fixed_train_n * (1 + val_frac + n_splits * test_frac)
# fixed_train_n <= n_samples / (1 + val_frac + n_splits * test_frac)
# Tentative initial train size is total minus one val/test block
fixed_train_n_est = self.n_samples - val_test_size_per_step
denominator = 1.0 + self.val_frac + self.n_splits * self.test_frac
if denominator <= 1.0: # Avoid division by zero or non-positive, and ensure train frac < 1
raise ValueError(f"Cannot estimate initial_train_size. Combination of val_frac ({self.val_frac}), "
f"test_frac ({self.test_frac}), and n_splits ({self.n_splits}) is invalid (denominator {denominator:.2f} <= 1.0).")
# Basic sanity checks
if fixed_train_n_est <= 0:
raise ValueError("Could not estimate a valid initial_train_size (<= 0). Please specify it or check CV fractions.")
# Need at least 1 sample for train, val, test each theoretically
est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
if fixed_train_n_est + est_val_size + est_test_size > self.n_samples:
# If the simple estimate is too large, reduce it more drastically
# Try setting train size = 50% and see if val/test fit?
fixed_train_n_est = int(self.n_samples * 0.5)
est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
if fixed_train_n_est <=0 or (fixed_train_n_est + est_val_size + est_test_size > self.n_samples):
raise ValueError("Could not estimate a valid initial_train_size. Data too small relative to val/test fractions? Please specify initial_train_size.")
estimated_size = int(self.n_samples / denominator)
logger.warning(f"initial_train_size not set, estimated fixed train size for rolling window: {fixed_train_n_est}. "
"This is a heuristic; viability depends on n_splits and step size. Validation happens in split().")
return fixed_train_n_est
# Add a sanity check: ensure estimated size is reasonably large
min_required_for_features = 1 # Placeholder - ideally get from FeatureConfig if possible, but complex here
if estimated_size < min_required_for_features:
raise ValueError(f"Estimated fixed train size ({estimated_size}) is too small. "
f"Check CV config (n_splits={self.n_splits}, val_frac={self.val_frac}, test_frac={self.test_frac}) "
f"relative to total samples ({self.n_samples}). Consider specifying initial_train_size manually.")
logger.info(f"Estimated fixed training window size: {estimated_size}")
return estimated_size
else:
raise ValueError(f"Invalid initial_train_size: {self.initial_train_size}")
raise ValueError(f"Invalid initial_train_size type or value: {self.initial_train_size}")
def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
@ -483,28 +487,31 @@ class TimeSeriesDataset(Dataset):
"""
PyTorch Dataset for time series forecasting.
Takes a NumPy array (features + target), sequence length, and forecast horizon,
and returns (input_sequence, target_sequence) tuples. Compatible with PyTorch
DataLoaders used by PyTorch Lightning.
Takes a NumPy array (features + target), sequence length, and a list of
specific forecast horizons. Returns (input_sequence, target_vector) tuples,
where target_vector contains the target values at the specified future steps.
"""
def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: int, target_col_index: int = 0):
def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: List[int], target_col_index: int = 0):
"""
Args:
data_array: Numpy array of shape (n_samples, n_features).
Assumes the target variable is one of the columns.
sequence_length: Length of the input sequence (lookback window).
forecast_horizon: Number of steps ahead to predict.
forecast_horizon: List of specific steps ahead to predict (e.g., [1, 6, 12]).
target_col_index: Index of the target column in data_array. Defaults to 0.
"""
if sequence_length <= 0:
raise ValueError("sequence_length must be positive.")
if forecast_horizon <= 0:
raise ValueError("forecast_horizon must be positive.")
if not forecast_horizon or not isinstance(forecast_horizon, list) or any(h <= 0 for h in forecast_horizon):
raise ValueError("forecast_horizon must be a non-empty list of positive integers.")
if data_array.ndim != 2:
raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")
min_len_required = sequence_length + forecast_horizon
self.max_horizon = max(forecast_horizon) # Find the furthest point needed
min_len_required = sequence_length + self.max_horizon
if min_len_required > data_array.shape[0]:
raise ValueError(f"sequence_length ({sequence_length}) + forecast_horizon ({forecast_horizon}) = {min_len_required} "
raise ValueError(f"sequence_length ({sequence_length}) + max_horizon ({self.max_horizon}) = {min_len_required} "
f"exceeds total samples provided ({data_array.shape[0]})")
if not (0 <= target_col_index < data_array.shape[1]):
raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")
@ -512,32 +519,37 @@ class TimeSeriesDataset(Dataset):
self.data = torch.tensor(data_array, dtype=torch.float32)
self.sequence_length = sequence_length
self.forecast_horizon = forecast_horizon
self.forecast_horizon_list = sorted(forecast_horizon)
self.target_col_index = target_col_index
self.n_samples = data_array.shape[0]
self.n_features = data_array.shape[1]
logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
f"seq_len={self.sequence_length}, forecast_horizon={self.forecast_horizon}, "
f"target_idx={self.target_col_index}")
f"seq_len={self.sequence_length}, forecast_horizons={self.forecast_horizon_list}, "
f"max_horizon={self.max_horizon}, target_idx={self.target_col_index}")
def __len__(self) -> int:
"""Returns the total number of sequences that can be generated."""
return self.n_samples - self.sequence_length - self.forecast_horizon + 1
return self.n_samples - self.sequence_length - self.max_horizon + 1
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a single (input_sequence, target_sequence) pair.
Returns a single (input_sequence, target_vector) pair.
Target vector contains values for the specified forecast horizons.
"""
if not (0 <= idx < len(self)):
raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")
input_start = idx
input_end = idx + self.sequence_length
input_sequence = self.data[input_start:input_end, :]
target_start = input_end
target_end = target_start + self.forecast_horizon
target_sequence = self.data[target_start:target_end, self.target_col_index]
return input_sequence, target_sequence
input_sequence = self.data[input_start:input_end, :] # Shape: (seq_len, n_features)
# Calculate indices for each horizon relative to the end of the input sequence
# Horizon h corresponds to index: input_end + h - 1
target_indices = [input_end + h - 1 for h in self.forecast_horizon_list]
target_vector = self.data[target_indices, self.target_col_index] # Shape: (len(forecast_horizon_list),)
return input_sequence, target_vector
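A tiny worked example of the new indexing, assuming the TimeSeriesDataset defined above is available (values chosen so the mapping is easy to follow):
import numpy as np
import torch

data = np.arange(10, dtype=np.float32).reshape(-1, 1)   # 10 rows, single target column
ds = TimeSeriesDataset(data, sequence_length=3, forecast_horizon=[1, 2, 4], target_col_index=0)
x, y = ds[0]
# Input covers rows 0..2; horizon h maps to row (3 + h - 1): h=1 -> 3, h=2 -> 4, h=4 -> 6.
assert torch.equal(x, torch.tensor([[0.0], [1.0], [2.0]]))
assert torch.equal(y, torch.tensor([3.0, 4.0, 6.0]))
assert len(ds) == 10 - 3 - 4 + 1   # n_samples - seq_len - max_horizon + 1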
# --- Data Preparation ---
def prepare_fold_data_and_loaders(
@ -577,6 +589,7 @@ def prepare_fold_data_and_loaders(
train_config: Configuration for training (used for batch size, device hints).
eval_config: Configuration for evaluation (used for batch size).
Returns:
Tuple containing:
- train_loader: DataLoader for the training set.
@ -598,13 +611,25 @@ def prepare_fold_data_and_loaders(
if feature_config.lags:
max_lookback = max(max_lookback, max(feature_config.lags))
if feature_config.rolling_window_sizes:
max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) -1 )
max_history_needed = max(max_lookback, feature_config.sequence_length)
max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) -1)
# Also account for the input sequence length and the furthest target horizon
max_horizon_needed = max(feature_config.forecast_horizon) if feature_config.forecast_horizon else 0
# History needed *before* train_idx[0]: enough rows for feature lookback (lags/rolling windows)
# and for the input sequence length. The max horizon only matters *after* the test set,
# so it extends the slice end rather than the slice start.
max_history_needed_before_train = max(max_lookback, feature_config.sequence_length)
slice_start_idx = max(0, train_idx[0] - max_history_needed_before_train)
# The end index needs to cover the test set PLUS the maximum horizon needed for the last test target
slice_end_idx = test_idx[-1] + max_horizon_needed # Go up to the last needed target
# Ensure end index is within bounds
slice_end_idx = min(slice_end_idx + 1, len(full_df)) # +1 because iloc is exclusive
slice_start_idx = max(0, train_idx[0] - max_history_needed)
slice_end_idx = test_idx[-1] + 1
if slice_start_idx >= slice_end_idx:
raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices.")
raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices and horizon.")
fold_data_slice = full_df.iloc[slice_start_idx:slice_end_idx]
logger.debug(f"Required data slice for fold: indices {slice_start_idx} to {slice_end_idx-1} "
@ -709,22 +734,38 @@ def prepare_fold_data_and_loaders(
input_size = train_data_scaled.shape[1]
# --- Ensure final data arrays are float32 for PyTorch ---
try:
# Explicitly convert to float32 AFTER scaling (or non-scaling)
train_data_final = train_data_scaled.astype(np.float32)
val_data_final = val_data_scaled.astype(np.float32)
test_data_final = test_data_scaled.astype(np.float32)
logger.debug("Ensured final data arrays are float32.")
except ValueError as e:
# This might happen if data cannot be safely cast (e.g., strings remain unexpectedly)
logger.error(f"Failed to convert data arrays to float32 before creating Tensors: {e}", exc_info=True)
# Consider adding more debug info here if it fails, e.g.:
# logger.debug(f"Data types in train_df before conversion: \n{train_df.dtypes}")
raise ValueError("Data could not be converted to numeric type (float32) for PyTorch.") from e
# 6. Dataset Instantiation
logger.debug("Creating TimeSeriesDataset instances for the fold.")
try:
# Use the explicitly converted arrays
train_dataset = TimeSeriesDataset(
train_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
train_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
val_dataset = TimeSeriesDataset(
val_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
val_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
test_dataset = TimeSeriesDataset(
test_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
test_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
except ValueError as e:
logger.error(f"Error creating TimeSeriesDataset: {e}")
logger.error(f"Shapes fed to Dataset: Train={train_data_scaled.shape}, Val={val_data_scaled.shape}, Test={test_data_scaled.shape}")
logger.error(f"SeqLen={feature_config.sequence_length}, Horizon={feature_config.forecast_horizon}")
logger.error(f"Shapes fed to Dataset: Train={train_data_final.shape}, Val={val_data_final.shape}, Test={test_data_final.shape}")
logger.error(f"SeqLen={feature_config.sequence_length}, Horizons={feature_config.forecast_horizon}")
raise
@ -749,3 +790,68 @@ def prepare_fold_data_and_loaders(
logger.info("Data loaders prepared successfully for the fold.")
return train_loader, val_loader, test_loader, target_scaler, input_size
# --- Classic Train/Val/Test Split ---
def split_data_classic(
n_samples: int,
val_frac: float,
test_frac: float,
start_from_end: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Splits data indices into one train, one validation, and one test set based on fractions.
Args:
n_samples: Total number of samples in the dataset.
val_frac: Fraction of the *total* data to use for validation.
test_frac: Fraction of the *total* data to use for testing.
start_from_end: If True (default, the typical choice for time series evaluation),
the validation and test sets are taken from the very end of the series.
If False, they are placed directly after the initial training block and any trailing data is left unused.
Returns:
Tuple of (train_indices, val_indices, test_indices).
Raises:
ValueError: If fractions are invalid or sum to >= 1.
"""
if not (0 < val_frac < 1):
raise ValueError(f"val_frac must be between 0 and 1, got {val_frac}")
if not (0 < test_frac < 1):
raise ValueError(f"test_frac must be between 0 and 1, got {test_frac}")
if val_frac + test_frac >= 1:
raise ValueError(f"Sum of val_frac ({val_frac}) and test_frac ({test_frac}) must be less than 1.")
test_size = math.ceil(n_samples * test_frac) # Use ceil to ensure at least one sample if frac is tiny
val_size = math.ceil(n_samples * val_frac)
train_size = n_samples - val_size - test_size
if train_size <= 0:
raise ValueError(f"Calculated train_size ({train_size}) is not positive. Adjust fractions or increase data.")
if val_size <= 0:
raise ValueError("Calculated val_size is not positive.")
if test_size <= 0:
raise ValueError("Calculated test_size is not positive.")
indices = np.arange(n_samples)
if start_from_end:
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]
# Adjust if ceil caused slight overallocation in test
test_indices = test_indices[:test_size]
else:
# Less common: place val/test directly after train
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:train_size + val_size + test_size]
# Remaining data is unused in this scenario
logger.info(f"Classic split: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")
return train_indices, val_indices, test_indices
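A hedged usage sketch (variable names are placeholders) showing how the classic split plugs into the existing fold-preparation code:
# df is assumed to be the fully loaded DataFrame; fractions are illustrative.
train_idx, val_idx, test_idx = split_data_classic(
    n_samples=len(df),
    val_frac=0.15,
    test_frac=0.15,
)
# The three index arrays can then be passed to prepare_fold_data_and_loaders
# exactly like a single fold produced by TimeSeriesCrossValidationSplitter.split().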

View File

@ -1,24 +1,22 @@
import logging
import os
from pathlib import Path # Added
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler # For type hinting target_scaler
from typing import Dict, Any, Optional, Union, List, Tuple
# import matplotlib.pyplot as plt # No longer needed directly
# import seaborn as sns # No longer needed directly
from typing import Dict, Optional, Union, List
import pandas as pd # For time index type hint
# Assuming config_model and io.plotting are accessible
from forecasting_model.utils.config_model import EvaluationConfig
from forecasting_model.io.plotting import ( # Import the plotting utilities
from forecasting_model.utils.forecast_config_model import EvaluationConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.io.plotting import (
setup_plot_style,
save_plot,
create_time_series_plot,
create_scatter_plot,
create_residuals_plot,
create_residuals_distribution_plot
create_residuals_distribution_plot,
)
@ -82,90 +80,101 @@ def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
return float(rmse)
# --- Plotting Functions (Utilities) ---
# REMOVED - These are now imported from io.plotting
# --- Fold Evaluation Function ---
def evaluate_fold_predictions(
y_true_scaled: np.ndarray,
y_pred_scaled: np.ndarray,
y_true_scaled: np.ndarray, # Shape: (n_samples, len(horizons))
y_pred_scaled: np.ndarray, # Shape: (n_samples, len(horizons))
target_scaler: Union[StandardScaler, MinMaxScaler, None],
eval_config: EvaluationConfig,
fold_num: int,
output_dir: str, # Base output directory (e.g., output/cv_results)
time_index: Optional[np.ndarray] = None # Optional: Pass time index for x-axis
fold_num: int, # Zero-based fold index
output_dir: str, # Base output directory
plot_subdir: Optional[str] = "plots",
# time_index: Optional[Union[np.ndarray, pd.Index]] = None, # OLD: Index for samples
prediction_time_index: Optional[pd.Index] = None, # Index corresponding to the prediction times (n_samples,)
forecast_horizons: Optional[List[int]] = None, # The list of horizons predicted (e.g., [1, 6, 12])
plot_title_prefix: Optional[str] = None
) -> Dict[str, float]:
"""
Processes prediction results for a fold's test set using torchmetrics.
Processes prediction results (multiple horizons) for a fold or ensemble.
Takes scaled predictions and targets, inverse transforms them,
calculates final metrics (MAE, RMSE) using torchmetrics.functional,
and generates evaluation plots using utilities from io.plotting. Assumes
model inference is already done.
Takes scaled predictions and targets (shape: samples, num_horizons),
inverse transforms them, calculates overall metrics (MAE, RMSE) across all horizons,
and generates evaluation plots *for the first specified horizon only*.
Args:
y_true_scaled: Numpy array of scaled ground truth targets (n_samples, horizon).
y_pred_scaled: Numpy array of scaled model predictions (n_samples, horizon).
target_scaler: The scaler fitted on the target variable during training. Needed
for inverse transforming to original scale. Can be None.
eval_config: Configuration object for evaluation parameters (e.g., plotting).
fold_num: The current fold number (e.g., 0, 1, ...).
output_dir: The base directory to save fold-specific outputs (plots, metrics).
time_index: Optional array representing the time index for the test set,
used for x-axis in time-based plots. If None, uses integer indices.
y_true_scaled: Numpy array of scaled ground truth targets (n_samples, len(horizons)).
y_pred_scaled: Numpy array of scaled model predictions (n_samples, len(horizons)).
target_scaler: The scaler fitted on the target variable.
eval_config: Configuration object for evaluation parameters.
fold_num: The current fold number (zero-based or -1 for classic).
output_dir: The base directory to save outputs.
plot_subdir: Specific subdirectory under output_dir for plots.
prediction_time_index: Pandas Index representing the time for each prediction point (n_samples,).
Required for meaningful time plots.
forecast_horizons: List of horizons predicted (e.g., [1, 6, 12]). Required for plotting.
plot_title_prefix: Optional string to prepend to plot titles.
Returns:
Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
original scale. Metrics will be NaN if inverse transform or calculation fails.
Raises:
ValueError: If input shapes are inconsistent or required scaler is missing.
original scale, calculated *across all predicted horizons*.
"""
logger.info(f"Processing evaluation results for Fold {fold_num + 1}...")
fold_id = fold_num + 1 # Use 1-based indexing for reporting/filenames
fold_id_str = f"Fold {fold_num + 1}" if fold_num >= 0 else "Classic Run"
eval_context_str = f"{plot_title_prefix} {fold_id_str}" if plot_title_prefix else fold_id_str
logger.info(f"Processing evaluation results for: {eval_context_str}")
if y_true_scaled.shape != y_pred_scaled.shape:
raise ValueError(f"Shape mismatch between targets and predictions: "
raise ValueError(f"Shape mismatch between targets and predictions for {eval_context_str}: "
f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
if y_true_scaled.ndim != 2:
raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")
raise ValueError(f"Expected 2D arrays (samples, num_horizons) for {eval_context_str}, got {y_true_scaled.ndim}D")
n_samples, horizon = y_true_scaled.shape
logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")
n_samples, n_horizons = y_true_scaled.shape
logger.debug(f"Processing {n_samples} samples across {n_horizons} horizons for {eval_context_str}.")
# --- Inverse Transform (Outputs NumPy) ---
y_true_flat_scaled = y_true_scaled.reshape(-1, 1)
y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)
# Flatten the multi-horizon arrays for the scaler (which expects (N, 1))
y_true_flat_scaled = y_true_scaled.reshape(-1, 1) # Shape: (n_samples * n_horizons, 1)
y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1) # Shape: (n_samples * n_horizons, 1)
y_true_inv_np: np.ndarray
y_pred_inv_np: np.ndarray
if target_scaler is not None:
try:
logger.debug("Inverse transforming predictions and targets.")
y_true_inv_np = target_scaler.inverse_transform(y_true_flat_scaled)
y_pred_inv_np = target_scaler.inverse_transform(y_pred_flat_scaled)
# Flatten NumPy arrays for metric calculation and plotting
y_true_np = y_true_inv_np.flatten()
y_pred_np = y_pred_inv_np.flatten()
logger.debug(f"Inverse transforming predictions and targets for {eval_context_str}.")
y_true_inv_flat = target_scaler.inverse_transform(y_true_flat_scaled)
y_pred_inv_flat = target_scaler.inverse_transform(y_pred_flat_scaled)
# Reshape back to (n_samples, n_horizons) for potential per-horizon analysis later
y_true_inv_np = y_true_inv_flat.reshape(n_samples, n_horizons)
y_pred_inv_np = y_pred_inv_flat.reshape(n_samples, n_horizons)
except Exception as e:
logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
logger.error(f"Error during inverse scaling for {eval_context_str}: {e}", exc_info=True)
logger.error("Metrics calculation will be skipped due to inverse transform failure.")
return {'MAE': np.nan, 'RMSE': np.nan}
else:
logger.info("No target scaler provided, assuming inputs are already on original scale.")
# Flatten NumPy arrays for metric calculation and plotting
y_true_np = y_true_flat_scaled.flatten()
y_pred_np = y_pred_flat_scaled.flatten()
# --- Calculate Metrics using torchmetrics.functional ---
metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Initialize with NaN
else:
logger.info(f"No target scaler provided for {eval_context_str}, assuming inputs are on original scale.")
y_true_inv_np = y_true_scaled # Keep original shape (n_samples, n_horizons)
y_pred_inv_np = y_pred_scaled # Keep original shape
# --- Calculate Metrics using torchmetrics.functional (Overall across all horizons) ---
metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}
try:
if len(y_true_np) > 0: # Check if data exists after potential failures
y_true_tensor = torch.from_numpy(y_true_np).float().cpu()
y_pred_tensor = torch.from_numpy(y_pred_np).float().cpu()
# Flatten arrays for overall metrics calculation
y_true_flat_for_metrics = y_true_inv_np.flatten()
y_pred_flat_for_metrics = y_pred_inv_np.flatten()
valid_mask = ~np.isnan(y_true_flat_for_metrics) & ~np.isnan(y_pred_flat_for_metrics)
if np.sum(valid_mask) < len(y_true_flat_for_metrics):
nan_count = len(y_true_flat_for_metrics) - np.sum(valid_mask)
logger.warning(f"{nan_count} NaN values found in predictions/targets (across all horizons) for {eval_context_str}. These will be excluded from metrics.")
if np.sum(valid_mask) > 0:
y_true_tensor = torch.from_numpy(y_true_flat_for_metrics[valid_mask]).float().cpu()
y_pred_tensor = torch.from_numpy(y_pred_flat_for_metrics[valid_mask]).float().cpu()
mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
@ -174,82 +183,95 @@ def evaluate_fold_predictions(
metrics['MAE'] = mae_tensor.item()
metrics['RMSE'] = rmse_tensor.item()
logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
logger.info(f"{eval_context_str} Test Set Overall Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f} (across all horizons)")
else:
logger.warning(f"Skipping metric calculation for Fold {fold_id} due to empty data after inverse transform.")
logger.warning(f"Skipping metric calculation for {eval_context_str} due to no valid (non-NaN) data points.")
except Exception as e:
logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
# metrics already initialized to NaN
logger.error(f"Failed to calculate overall metrics using torchmetrics for {eval_context_str}: {e}", exc_info=True)
# --- Generate Plots (Optional - uses plotting utilities) ---
if eval_config.save_plots and len(y_true_np) > 0:
logger.info(f"Generating evaluation plots for Fold {fold_id}...")
# Define plot directory and setup style
fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
setup_plot_style() # Apply consistent styling
# --- Generate Plots (Optional - Focus on FIRST horizon) ---
if eval_config.save_plots and np.sum(valid_mask) > 0:
if forecast_horizons is None or not forecast_horizons:
logger.warning(f"Skipping plot generation for {eval_context_str}: `forecast_horizons` list not provided.")
elif prediction_time_index is None or len(prediction_time_index) != n_samples:
logger.warning(f"Skipping plot generation for {eval_context_str}: `prediction_time_index` is missing or has incorrect length ({len(prediction_time_index) if prediction_time_index is not None else 'None'} != {n_samples}).")
else:
logger.info(f"Generating evaluation plots for {eval_context_str} (using first horizon H+{forecast_horizons[0]} only)...")
base_plot_dir = Path(output_dir)
fold_plot_dir = base_plot_dir / plot_subdir if plot_subdir else base_plot_dir
setup_plot_style()
title_suffix = f"Fold {fold_id} Test Set"
residuals_np = y_true_np - y_pred_np
# --- Plotting for the FIRST horizon ---
first_horizon = forecast_horizons[0]
y_true_h1 = y_true_inv_np[:, 0] # Data for the first horizon
y_pred_h1 = y_pred_inv_np[:, 0] # Data for the first horizon
residuals_h1 = y_true_h1 - y_pred_h1
# Determine x-axis: use provided time_index if available, else integer indices
# Note: Flattened y_true/y_pred have length n_samples * horizon
# Need an appropriate index for this flattened view if time_index is provided.
# Simple approach: use integer indices for flattened data.
plot_indices = np.arange(len(y_true_np))
xlabel = "Time Index (Flattened Horizon x Samples)"
# If time_index corresponding to the start of each forecast is passed,
# more sophisticated x-axis handling could be done, but integer indices are simpler.
# Calculate the actual time index for the first horizon's targets
# Requires the original dataset's frequency if available, otherwise assumes simple offset
target_time_index_h1 = prediction_time_index
try:
# Assuming prediction_time_index corresponds to the *time* of prediction
# The target for H+h occurs `h` steps later.
# This requires a DatetimeIndex with a frequency.
if isinstance(prediction_time_index, pd.DatetimeIndex) and prediction_time_index.freq:
time_offset = pd.Timedelta(first_horizon, unit=prediction_time_index.freq.name)
target_time_index_h1 = prediction_time_index + time_offset
xlabel_h1 = f"Time (Target H+{first_horizon})"
else:
logger.warning(f"Prediction time index lacks frequency info. Using original prediction time for H+{first_horizon} plot x-axis.")
xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"
except Exception as time_err:
logger.warning(f"Could not calculate target time index for H+{first_horizon}: {time_err}. Using prediction time index for x-axis.")
xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"
title_suffix = f"- {eval_context_str} (H+{first_horizon})"
try:
# Create and save each plot using utility functions
fig_ts = create_time_series_plot(
plot_indices, y_true_np, y_pred_np,
f"Predictions vs Actual - {title_suffix}",
xlabel=xlabel,
ylabel="Value (Original Scale)",
target_time_index_h1, y_true_h1, y_pred_h1, # Use H1 data and time
f"Predictions vs Actual {title_suffix}",
xlabel=xlabel_h1, ylabel="Value (Original Scale)",
max_points=eval_config.plot_sample_size
)
save_plot(fig_ts, fold_plot_dir / "predictions_vs_actual.png")
save_plot(fig_ts, fold_plot_dir / f"predictions_vs_actual_h{first_horizon}.png")
fig_scatter = create_scatter_plot(
y_true_np, y_pred_np,
f"Scatter Plot - {title_suffix}",
xlabel="Actual Values (Original Scale)",
ylabel="Predicted Values (Original Scale)"
y_true_h1, y_pred_h1, # Use H1 data
f"Scatter Plot {title_suffix}",
xlabel="Actual Values (Original Scale)", ylabel="Predicted Values (Original Scale)"
)
save_plot(fig_scatter, fold_plot_dir / "scatter_predictions.png")
save_plot(fig_scatter, fold_plot_dir / f"scatter_predictions_h{first_horizon}.png")
fig_res_time = create_residuals_plot(
plot_indices, residuals_np,
f"Residuals Over Time - {title_suffix}",
xlabel=xlabel,
ylabel="Residual (Original Scale)",
target_time_index_h1, residuals_h1, # Use H1 residuals and time
f"Residuals Over Time {title_suffix}",
xlabel=xlabel_h1, ylabel="Residual (Original Scale)",
max_points=eval_config.plot_sample_size
)
save_plot(fig_res_time, fold_plot_dir / "residuals_time.png")
save_plot(fig_res_time, fold_plot_dir / f"residuals_time_h{first_horizon}.png")
# Residual distribution can use residuals from ALL horizons
residuals_all = y_true_inv_np.flatten() - y_pred_inv_np.flatten()
fig_res_dist = create_residuals_distribution_plot(
residuals_np,
f"Residuals Distribution - {title_suffix}",
xlabel="Residual Value (Original Scale)",
ylabel="Density"
residuals_all, # Use all residuals
f"Residuals Distribution {eval_context_str} (All Horizons)", # Adjusted title
xlabel="Residual Value (Original Scale)", ylabel="Density"
)
save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution.png")
save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution_all_horizons.png")
logger.info(f"Evaluation plots saved to: {fold_plot_dir}")
except Exception as e:
logger.error(f"Failed to generate or save one or more plots for Fold {fold_id}: {e}", exc_info=True)
# Continue without plots, metrics are already calculated.
logger.error(f"Failed to generate or save one or more plots for {eval_context_str}: {e}", exc_info=True)
elif eval_config.save_plots and len(y_true_np) == 0:
logger.warning(f"Skipping plot generation for Fold {fold_id} due to empty data.")
elif eval_config.save_plots and np.sum(valid_mask) == 0:
logger.warning(f"Skipping plot generation for {eval_context_str} due to no valid data points.")
logger.info(f"Evaluation processing finished for Fold {fold_id}.")
logger.info(f"Evaluation processing finished for {eval_context_str}.")
return metrics
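A minimal sketch of the overall-metric computation performed above, flattening the (samples, horizons) arrays before calling torchmetrics (toy values; the real function also masks NaNs first):
import numpy as np
import torch
import torchmetrics

y_true = np.array([[10.0, 20.0], [11.0, 21.0]])   # (n_samples=2, n_horizons=2), original scale
y_pred = np.array([[ 9.0, 22.0], [12.0, 20.0]])
t_true = torch.from_numpy(y_true.flatten()).float()
t_pred = torch.from_numpy(y_pred.flatten()).float()
mae = torchmetrics.functional.mean_absolute_error(t_pred, t_true)               # 1.25
rmse = torch.sqrt(torchmetrics.functional.mean_squared_error(t_pred, t_true))   # ~1.32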
@ -257,63 +279,90 @@ def evaluate_fold_predictions(
# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
model: torch.nn.Module,
model: LSTMForecastLightningModule, # Use the specific type
test_loader: DataLoader,
device: torch.device,
target_scaler: Union[StandardScaler, MinMaxScaler, None],
eval_config: EvaluationConfig,
fold_num: int,
output_dir: str
output_dir: str,
# time_index: Optional[Union[np.ndarray, pd.Index]] = None, # OLD
prediction_time_index: Optional[pd.Index] = None, # Pass prediction time index
forecast_horizons: Optional[List[int]] = None # Pass horizons
) -> Dict[str, float]:
"""
[Optional Function] Evaluates a given model on a fold's test set.
Runs the inference loop, collects scaled results, then processes them using
`evaluate_fold_predictions` (which now uses plotting utilities).
Useful for standalone testing or if not using pl.Trainer.test().
Handles multiple forecast horizons.
"""
# ... (Implementation of inference loop remains the same) ...
logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
model.eval()
model.to(device)
all_preds_scaled_list: List[torch.Tensor] = []
all_targets_scaled_list: List[torch.Tensor] = []
with torch.no_grad():
for i, (X_batch, y_batch) in enumerate(test_loader):
for i, batch in enumerate(test_loader):
try:
X_batch = X_batch.to(device)
outputs = model(X_batch) # Scaled outputs
if isinstance(batch, (list, tuple)) and len(batch) == 2:
X_batch, y_batch = batch # y_batch shape: (batch, len(horizons))
targets_present = True
else:
X_batch = batch
y_batch = None
targets_present = False
# Ensure outputs match target shape (e.g., handle trailing dimension)
if outputs.shape != y_batch.shape:
if outputs.ndim == y_batch.ndim + 1 and outputs.shape[-1] == 1:
outputs = outputs.squeeze(-1)
if outputs.shape != y_batch.shape:
raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
X_batch = X_batch.to(device)
outputs = model(X_batch) # Scaled outputs: (batch, len(horizons))
all_preds_scaled_list.append(outputs.cpu())
all_targets_scaled_list.append(y_batch.cpu()) # Keep targets on CPU
if targets_present and y_batch is not None:
if outputs.shape != y_batch.shape:
raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
all_targets_scaled_list.append(y_batch.cpu())
# ... error/warning if targets expected but not found ...
except Exception as e:
logger.error(f"Error during inference batch {i} for Fold {fold_num+1}: {e}", exc_info=True)
raise ValueError(f"Inference failed on batch {i} for Fold {fold_num+1}")
# Concatenate results from all batches
# --- Concatenate results ---
try:
if not all_preds_scaled_list or not all_targets_scaled_list:
logger.error(f"No prediction results collected for Fold {fold_num + 1}. Check test_loader.")
if not all_preds_scaled_list:
# ... handle no predictions ...
return {'MAE': np.nan, 'RMSE': np.nan}
# Resulting shapes: (n_samples, len(horizons))
y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()
y_true_scaled = None
if all_targets_scaled_list:
y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
elif targets_present:
# ... handle missing targets ...
return {'MAE': np.nan, 'RMSE': np.nan}
else:
# ... handle no targets available ...
return {'MAE': np.nan, 'RMSE': np.nan}
y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()
y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
except Exception as e:
logger.error(f"Error concatenating prediction results for Fold {fold_num + 1}: {e}", exc_info=True)
# ... error handling ...
raise ValueError("Failed to combine batch results during evaluation inference.")
# Process the collected predictions using the refactored function
# No time_index passed here by default, plotting will use integer indices
if y_true_scaled is None:
# ... handle missing targets ...
return {'MAE': np.nan, 'RMSE': np.nan}
# Ensure forecast_horizons are passed if available from the model
# Retrieve from model's hparams if not passed explicitly
if forecast_horizons is None:
try:
# Assuming forecast_horizon list is stored in model_config hparam
forecast_horizons = model.hparams.model_config.forecast_horizon
except AttributeError:
logger.warning("Could not retrieve forecast_horizons from model hparams for evaluation.")
# Process the collected predictions
return evaluate_fold_predictions(
y_true_scaled=y_true_scaled,
y_pred_scaled=y_pred_scaled,
@ -321,5 +370,8 @@ def evaluate_model_on_fold_test_set(
eval_config=eval_config,
fold_num=fold_num,
output_dir=output_dir,
time_index=None # Explicitly pass None
# time_index=time_index # OLD
prediction_time_index=prediction_time_index, # Pass through
forecast_horizons=forecast_horizons, # Pass through
plot_title_prefix=f"Test Fold {fold_num + 1}" # Example prefix
)
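A hedged usage sketch for one cross-validation fold (all variables are placeholders assumed to exist in the calling code):
import torch

fold_metrics = evaluate_model_on_fold_test_set(
    model=trained_module,                       # the fold's LSTMForecastLightningModule
    test_loader=test_loader,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    target_scaler=target_scaler,
    eval_config=config.evaluation,
    fold_num=fold_idx,
    output_dir="output/cv_results",
    prediction_time_index=test_prediction_index,   # pd.Index, one entry per test sample
    forecast_horizons=config.features.forecast_horizon,
)
# fold_metrics -> {'MAE': ..., 'RMSE': ...} on the original scale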

View File

@ -1,11 +1,15 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
from typing import Optional, Union, List
import logging
import pandas as pd
from pathlib import Path
# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler
logger = logging.getLogger(__name__)
def setup_plot_style(use_seaborn: bool = True) -> None:
@ -17,14 +21,16 @@ def setup_plot_style(use_seaborn: bool = True) -> None:
"""
if use_seaborn:
try:
sns.set_theme(style="whitegrid", palette="muted")
plt.rcParams['figure.figsize'] = (12, 6) # Default figure size
# Use a different style that might be better for multiple lines
sns.set_theme(style="whitegrid", palette="viridis") # Changed palette
plt.rcParams['figure.figsize'] = (15, 7) # Slightly larger default figure size
logger.debug("Seaborn plot style set.")
except Exception as e:
logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
else:
# Optional: Define a default matplotlib style if seaborn is not used
plt.style.use('default')
plt.rcParams['figure.figsize'] = (15, 7)
logger.debug("Using default matplotlib plot style.")
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
@ -49,16 +55,21 @@ def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
logger.info(f"Plot saved successfully to: {filepath}")
except OSError as e:
logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
raise # Re-raise OSError for directory creation issues
# Don't re-raise immediately, try closing figure first
# raise # Re-raise OSError for directory creation issues - Removed to ensure finally runs
except Exception as e:
logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
raise # Re-raise other saving errors
# Don't re-raise immediately, try closing figure first
finally:
# Close the figure to free up memory, regardless of saving success
# Close the figure to free up memory, regardless of saving success or failure
try:
plt.close(fig)
logger.debug(f"Closed figure for plot {filepath}.")
except Exception as e:
logger.warning(f"Failed to close figure for plot {filepath}: {e}")
def create_time_series_plot(
x: np.ndarray,
x: Union[np.ndarray, pd.Index], # Allow pd.Index for time axis
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
@ -68,9 +79,9 @@ def create_time_series_plot(
) -> plt.Figure:
"""
Create a time series plot comparing actual vs predicted values.
NOTE: When using multi-horizon forecasts, this typically plots only ONE selected horizon.
Args:
x: The array for the x-axis (e.g., time steps, indices).
x: The array or index for the x-axis (e.g., time steps, datetime index). Should align with y_true/y_pred.
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
@ -84,8 +95,9 @@ def create_time_series_plot(
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
# Add check for pd.Index for x
if not isinstance(x, (np.ndarray, pd.Index)) or x.shape[0] != y_true.shape[0] or x.shape[0] != y_pred.shape[0] or y_true.ndim != 1 or y_pred.ndim != 1:
raise ValueError(f"Input shapes mismatch or invalid types: x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x.")
if len(x) == 0:
logger.warning("Attempting to create time series plot with empty data.")
# Return an empty figure or raise error? Let's return empty.
@ -305,3 +317,242 @@ def create_residuals_distribution_plot(
fig.tight_layout()
return fig
def create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons: np.ndarray, # (N, H)
y_pred_scaled_all_horizons: np.ndarray, # (N, H)
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
prediction_time_index_h1: pd.DatetimeIndex, # Time index for the first horizon predictions
forecast_horizons: List[int],
title: str,
xlabel: str = "Time",
ylabel: str = "Value (Original Scale)",
max_points: Optional[int] = 1000 # Limit points for clarity
) -> plt.Figure:
"""
Create a time series plot comparing actual values to predictions for multiple horizons.
Predictions for each horizon are plotted on their corresponding target time step.
Args:
y_true_scaled_all_horizons: Ground truth values (N, H array) on scaled scale.
y_pred_scaled_all_horizons: Predicted values (N, H array) on scaled scale.
target_scaler: The scaler used for the target variable, needed for inverse transform.
prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
Length should be N.
forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input shapes are incompatible or horizons list is invalid.
"""
if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
raise ValueError(f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match.")
if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
raise ValueError(f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}.")
if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
raise ValueError(f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match the number of predictions ({y_true_scaled_all_horizons.shape[0]}).")
if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
if not forecast_horizons or len(forecast_horizons) == 0:
raise ValueError("forecast_horizons list cannot be empty.")
logger.debug(f"Creating multi-horizon time series plot: {title}")
setup_plot_style() # Apply standard style
fig, ax = plt.subplots(figsize=(18, 8)) # Larger figure for multi-horizon
n_points = y_true_scaled_all_horizons.shape[0]
plot_indices = np.arange(n_points)
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = plot_indices[::step]
# Subsample the data and index
y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
time_index_h1_plot = prediction_time_index_h1[plot_indices]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
y_true_scaled_plot = y_true_scaled_all_horizons
y_pred_scaled_plot = y_pred_scaled_all_horizons
time_index_h1_plot = prediction_time_index_h1
effective_title = title
# Inverse transform the subsampled data
y_true_inv_plot = None
y_pred_inv_plot = None
if target_scaler is not None:
try:
# Scaler expects (N * H, 1), reshape (N, H) to (N*H, 1)
y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
# Reshape back to (N, H)
y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
except Exception as e:
logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
# Fallback to plotting scaled data if inverse transform fails
y_true_inv_plot = y_true_scaled_plot
y_pred_inv_plot = y_pred_scaled_plot
ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
if y_true_inv_plot is None or y_pred_inv_plot is None:
# This should not happen with the fallback, but as a safeguard
logger.error("Inverse transformed data is None, cannot plot.")
return fig # Return empty figure
# Plot Actuals (using h1's time index, as it's the reference point)
ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-', markersize=4, linewidth=1.5, color='black') # Actuals for H1
# Plot predictions for each horizon
colors = sns.color_palette("viridis", len(forecast_horizons)) # Use palette for distinct colors
linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1) # Cycle through linestyles
for i, horizon in enumerate(forecast_horizons):
preds_h = y_pred_inv_plot[:, i]
# Calculate time index for this specific horizon by shifting the h1 index
# Assumes the time index frequency is appropriate for the horizon steps
try:
time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h') # Assuming 'h' for hours
ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x', linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
except Exception as e:
logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. Skipping plot for this horizon.", exc_info=True)
# Configure plot appearance
ax.set_title(effective_title, fontsize=16) # Slightly larger title
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend(fontsize=10) # Smaller legend font
ax.grid(True, linestyle='--', alpha=0.6)
# Improve x-axis readability for datetimes
fig.autofmt_xdate() # Auto-rotate date labels
fig.tight_layout()
return fig
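A hedged usage sketch (inputs assumed to already exist with the documented shapes; the file path is illustrative):
fig = create_multi_horizon_time_series_plot(
    y_true_scaled_all_horizons=y_true_scaled,    # (N, H)
    y_pred_scaled_all_horizons=y_pred_scaled,    # (N, H)
    target_scaler=target_scaler,
    prediction_time_index_h1=pred_time_index,    # pd.DatetimeIndex with a frequency
    forecast_horizons=[1, 6, 12, 24],
    title="Forecasts vs Actuals (all horizons)",
)
save_plot(fig, Path("output/plots/multi_horizon_forecast.png"))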
def plot_loss_curve_from_csv(
metrics_csv_path: Union[str, Path],
output_path: Union[str, Path],
title: str = "Training Loss Curve",
train_loss_col: str = "train_loss", # Changed to match logging in model.py
val_loss_col: str = "val_loss", # Common validation loss metric logged by PL
epoch_col: str = "epoch"
) -> None:
"""
Reads training metrics from a PyTorch Lightning CSVLogger file and plots
training and validation loss curves over epochs.
Args:
metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
output_path: Path where the plot image will be saved.
title: Title for the plot.
train_loss_col: Name of the column containing epoch-level training loss.
val_loss_col: Name of the column containing epoch-level validation loss.
epoch_col: Name of the column containing the epoch number.
Raises:
FileNotFoundError: If the metrics_csv_path does not exist.
KeyError: If required columns are not found in the CSV.
Exception: For other plotting or file reading errors.
"""
logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
metrics_path = Path(metrics_csv_path)
if not metrics_path.is_file():
raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")
try:
metrics_df = pd.read_csv(metrics_path)
# Check if required columns exist
required_cols = [epoch_col, train_loss_col]
# Val loss column might be the scaled loss or the original scale MAE
possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale'] # Include potential names
found_val_col = None
for col in possible_val_cols:
if col in metrics_df.columns:
found_val_col = col
break
if not found_val_col:
missing_cols = [col for col in required_cols if col not in metrics_df.columns]
raise KeyError(f"Missing required columns in {metrics_path}: {missing_cols} or a suitable validation loss/metric column from {possible_val_cols}.")
# --- Plotting ---
setup_plot_style() # Apply standard style
fig, ax1 = plt.subplots(figsize=(12, 6))
color1 = 'tab:red'
ax1.set_xlabel(epoch_col.capitalize())
# Adjust ylabel based on actual column name used for train loss
ax1.set_ylabel(train_loss_col.replace('_epoch','').replace('_',' ').capitalize(), color=color1)
# Drop NaNs specific to this column for plotting integrity
train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])
# Filter for epoch column only if needed (usually not for loss plots)
# train_plot_data = train_plot_data[train_plot_data[epoch_col].notna()]
# Ensure epoch starts from 0 or 1 consistently
if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
else:
ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')
# Validation loss/metric plotting on twin axis
ax2 = ax1.twinx()
color2 = 'tab:blue'
# Adjust ylabel based on actual column name used for val metric
ax2.set_ylabel(found_val_col.replace('_epoch','').replace('_',' ').capitalize(), color=color2)
# Drop NaNs specific to the found validation column
val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])
# val_plot_data = val_plot_data[val_plot_data[epoch_col].notna()] # Ensure epoch is not NaN
# Ensure epoch starts from 0 or 1 consistently
if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
else:
ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
ax2.tick_params(axis='y', labelcolor=color2)
# Add legend manually combining lines from both axes
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='upper right')
plt.title(title, fontsize=14)
fig.tight_layout() # Otherwise the right y-label is slightly clipped
# Save the plot
save_plot(fig, output_path)
except pd.errors.EmptyDataError:
logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
except KeyError as e:
logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
raise # Re-raise specific error after logging
except Exception as e:
logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
raise # Re-raise general errors
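# --- Illustrative usage sketch (editor addition, not part of the original module) ---
# Minimal example of calling plot_loss_curve_from_csv after a training run. The
# paths below are assumptions; the real metrics.csv lives wherever the CSVLogger
# wrote it (typically .../training_logs/version_*/metrics.csv).
if __name__ == "__main__":
    plot_loss_curve_from_csv(
        metrics_csv_path="output/classic_run/training_logs/version_0/metrics.csv",  # assumed path
        output_path="output/classic_run/loss_curve.png",  # assumed path
        title="Training Loss Curve (example)",
        train_loss_col="train_loss",
        val_loss_col="val_loss",
    )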

View File

@ -0,0 +1,20 @@
"""
TODO
"""
__version__ = "0.1.0"
# Expose core components for easier import
from .ensemble_evaluation import (
run_ensemble_evaluation
)
# Expose main configuration class from utils
from ..utils import MainConfig
# Define __all__ for explicit public API (optional but good practice)
__all__ = [
"run_ensemble_evaluation",
"MainConfig",
]
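# Illustrative usage (editor note): downstream code can then import the public API in
# one line, assuming the package path matches the repository layout:
#   from forecasting_model.train import run_ensemble_evaluation, MainConfig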

View File

@ -0,0 +1,276 @@
"""
Classic training routine: Train on initial data segment, validate and test on final segments.
"""
import logging
import time
from pathlib import Path
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from typing import Dict, Optional
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.utils.helper import save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv
logger = logging.getLogger(__name__)
def run_classic_training(
config: MainConfig,
full_df: pd.DataFrame,
output_base_dir: Path
) -> Optional[Dict[str, float]]:
"""
Runs a single training pipeline using a classic train/val/test split.
Args:
config: The main configuration object.
full_df: The complete raw DataFrame.
output_base_dir: The base directory where general outputs are saved.
Classic results will be saved in a subdirectory.
Returns:
A dictionary containing test metrics (e.g., {'MAE': ..., 'RMSE': ...})
for the classic run, or None if it fails.
"""
run_start_time = time.perf_counter()
logger.info("--- Starting Classic Training Run ---")
# Define a specific output directory for this run
classic_output_dir = output_base_dir / "classic_run"
classic_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Classic run outputs will be saved to: {classic_output_dir}")
test_metrics: Optional[Dict[str, float]] = None
best_val_score: Optional[float] = None
best_model_path: Optional[str] = None
try:
# --- Data Splitting ---
logger.info("Splitting data into classic train/val/test sets...")
n_samples = len(full_df)
val_frac = config.cross_validation.val_size_fraction
test_frac = config.cross_validation.test_size_fraction
train_idx, val_idx, test_idx = split_data_classic(n_samples, val_frac, test_frac)
# Store test datetime index for evaluation plotting
test_datetime_index = full_df.iloc[test_idx].index
# --- Data Preparation ---
logger.info("Preparing data loaders for the classic split...")
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=full_df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col,
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
)
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# Save artifacts specific to this run if needed (e.g., for later inference)
torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
torch.save(input_size, classic_output_dir / "classic_input_size.pt")
# Save config for this run
# Support both Pydantic v2 (`model_dump`) and v1 (`dict`) when serializing the config
try: config_dump = config.model_dump()
except AttributeError: config_dump = config.dict()
with open(classic_output_dir / "config.yaml", 'w') as f:
yaml.dump(config_dump, f, default_flow_style=False)
# --- Model Initialization ---
model = LSTMForecastLightningModule(
model_config=config.model,
train_config=config.training,
input_size=input_size,
target_scaler=target_scaler
)
logger.info("Classic LSTMForecastLightningModule initialized.")
# --- PyTorch Lightning Callbacks ---
monitor_metric = "val_MeanAbsoluteError" # Monitor same metric as CV folds
monitor_mode = "min"
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric, min_delta=0.0001,
patience=config.training.early_stopping_patience, verbose=True, mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
checkpoint_callback = ModelCheckpoint(
dirpath=classic_output_dir / "checkpoints",
filename="best_classic_model", # Simple filename
save_top_k=1, monitor=monitor_metric, mode=monitor_mode, verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback: callbacks.append(early_stop_callback)
# --- PyTorch Lightning Logger ---
pl_logger = CSVLogger(save_dir=str(classic_output_dir), name="training_logs")
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None
precision = getattr(config.training, 'precision', 32)
trainer = pl.Trainer(
accelerator=accelerator, devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks, logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10),
enable_progress_bar=True,
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
# --- Training ---
logger.info("Starting classic model training...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info("Classic model training finished.")
# Store best validation score and path
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
best_model_path = trainer.checkpoint_callback.best_model_path
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}): {best_val_score:.4f}")
logger.info(f"Best model checkpoint path: {best_model_path}")
else:
logger.warning(f"Could not retrieve best validation score/path (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path = None
# --- Prediction on Test Set ---
logger.info("Starting prediction on classic test set using best checkpoint...")
prediction_results_list = trainer.predict(
ckpt_path=best_model_path if best_model_path else 'last',
dataloaders=test_loader
)
# --- Evaluation ---
if not prediction_results_list:
logger.error("Predict phase did not return any results for classic run.")
test_metrics = None
else:
try:
# Shapes: (n_samples, len(horizons))
all_preds_scaled = torch.cat([b['preds_scaled'] for b in prediction_results_list], dim=0).numpy()
n_predictions = len(all_preds_scaled) # Number of samples actually predicted
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([b['targets_scaled'] for b in prediction_results_list], dim=0).numpy()
if len(all_targets_scaled) != n_predictions:
logger.error(f"Classic Run: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
raise ValueError("Prediction and target count mismatch during classic evaluation.")
else:
raise ValueError("Targets missing from prediction results.")
logger.info(f"Processing {n_predictions} prediction results for classic test set...")
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
target_time_index_for_plotting = None
if test_idx is not None and config.features.forecast_horizon:
try:
test_block_index = full_df.index[test_idx] # Use the test_idx from classic split
seq_len = config.features.sequence_length
first_horizon = config.features.forecast_horizon[0]
start_offset = seq_len + first_horizon - 1
if start_offset < len(test_block_index):
end_index = min(start_offset + n_predictions, len(test_block_index))
target_time_index_for_plotting = test_block_index[start_offset:end_index]
if len(target_time_index_for_plotting) != n_predictions:
logger.warning(f"Classic Run: Calculated target time index length ({len(target_time_index_for_plotting)}) "
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
target_time_index_for_plotting = None
else:
logger.warning(f"Classic Run: Cannot calculate target time index, start offset ({start_offset}) "
f"exceeds test block length ({len(test_block_index)}).")
except Exception as e:
logger.error(f"Classic Run: Error calculating target time index for plotting: {e}", exc_info=True)
target_time_index_for_plotting = None # Ensure it's None if error occurs
else:
logger.warning(f"Classic Run: Skipping target time index calculation (missing test_idx or forecast_horizon).")
# --- End Index Calculation ---
# Use the classic run specific objects and config
test_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled,
y_pred_scaled=all_preds_scaled,
target_scaler=target_scaler,
eval_config=config.evaluation,
fold_num=-1, # Indicate classic run
output_dir=str(classic_output_dir),
plot_subdir="plots",
prediction_time_index=target_time_index_for_plotting, # Pass the correctly calculated index
forecast_horizons=config.features.forecast_horizon,
plot_title_prefix="Classic Run"
)
# Save metrics
save_results({"overall_metrics": test_metrics}, classic_output_dir / "test_metrics.json")
logger.info(f"Classic run test metrics (overall): {test_metrics}")
# --- Plot Loss Curve for Classic Run ---
try:
# Adjusted logic to find metrics.csv inside potential version_*/ directories
classic_log_dir = classic_output_dir / "training_logs"
metrics_file = None
version_dirs = list(classic_log_dir.glob("version_*"))
if version_dirs:
# Assuming the latest version directory contains the relevant logs
latest_version_dir = max(version_dirs, key=lambda p: p.stat().st_mtime)
potential_metrics_file = latest_version_dir / "metrics.csv"
if potential_metrics_file.is_file():
metrics_file = potential_metrics_file
else:
logger.warning(f"Classic Run: metrics.csv not found in latest version directory: {latest_version_dir}")
else:
# Fallback if no version_* directories exist (less common with CSVLogger)
potential_metrics_file = classic_log_dir / "metrics.csv"
if potential_metrics_file.is_file():
metrics_file = potential_metrics_file
if metrics_file and metrics_file.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file,
output_path=classic_output_dir / "loss_curve.png",
title="Classic Run Training Progression",
train_loss_col='train_loss', # Changed from 'train_loss_epoch'
val_loss_col='val_loss' # Keep as 'val_loss'
)
logger.info(f"Generating loss curve for classic run from: {metrics_file}")
else:
logger.warning(f"Classic Run: Could not find metrics.csv in {classic_log_dir} or its version subdirectories for loss curve plot.")
except Exception as plot_e:
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
# --- End Classic Loss Plotting ---
except (KeyError, ValueError, Exception) as e:
logger.error(f"Error processing classic prediction results: {e}", exc_info=True)
test_metrics = None
except Exception as e:
logger.error(f"An error occurred during the classic training pipeline: {e}", exc_info=True)
test_metrics = None # Indicate failure
finally:
if torch.cuda.is_available(): torch.cuda.empty_cache()
run_end_time = time.perf_counter()
logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---")
return test_metrics
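# --- Illustrative usage sketch (editor addition, not part of the original module) ---
# Minimal example of driving the classic run directly; in the project this is
# normally invoked from the main runner script. The config and CSV paths below are
# assumptions, and the raw DataFrame is expected to have a datetime index.
if __name__ == "__main__":
    with open("config.yaml", "r") as f:  # assumed config location
        cfg = MainConfig(**yaml.safe_load(f))
    # Hypothetical raw data file; the project normally loads this via its own
    # data_processing helpers.
    df = pd.read_csv("data/hourly_prices.csv", index_col=0, parse_dates=True)
    metrics = run_classic_training(cfg, full_df=df, output_base_dir=Path("output"))
    print("Classic run test metrics:", metrics)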

View File

@ -0,0 +1,425 @@
"""
Ensemble evaluation for time series forecasting models.
This module provides functionality to evaluate ensemble predictions
by combining predictions from n-1 folds and testing on the remaining fold.
"""
import logging
import numpy as np
import torch
import yaml # For loading fold config
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd # For time index handling
import pickle # Need pickle for the specific error check
from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig
logger = logging.getLogger(__name__)
def load_fold_model_and_objects(
fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
"""
Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.
Args:
fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).
Returns:
A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
or None if any essential artifact is missing or loading fails.
"""
try:
logger.info(f"Loading artifacts from: {fold_dir}")
# 1. Load Fold Configuration
config_path = fold_dir / "config.yaml"
if not config_path.is_file():
logger.error(f"Fold config file not found in {fold_dir}")
return None
with open(config_path, 'r') as f:
fold_config_dict = yaml.safe_load(f)
fold_config = MainConfig(**fold_config_dict) # Validate fold's config
# 2. Load Saved Objects using torch.load
test_loader_path = fold_dir / "test_loader.pt"
scaler_path = fold_dir / "target_scaler.pt"
input_size_path = fold_dir / "input_size.pt"
prediction_index_path = fold_dir / "prediction_target_time_index.pt"
if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
return None
try:
# --- Explicitly set weights_only=False for non-model objects ---
test_loader = torch.load(test_loader_path, weights_only=False)
target_scaler = torch.load(scaler_path, weights_only=False)
input_size = torch.load(input_size_path, weights_only=False)
# --- End Modification ---
except pickle.UnpicklingError as e:
# Catch potential unpickling errors even with weights_only=False
logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
return None
except AttributeError as e:
# Catch potential issues if class definitions changed between saving and loading
logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
return None
except Exception as e:
# Catch other potential loading errors
logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
return None
# Retrieve forecast horizon list from the fold's config
forecast_horizons = fold_config.features.forecast_horizon
# --- Extract prediction target time index (if available) ---
prediction_target_time_index: Optional[pd.Index] = None
if prediction_index_path.is_file():
try:
prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
# Basic validation
if not isinstance(prediction_target_time_index, pd.Index):
logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
prediction_target_time_index = None
else:
logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
except Exception as e:
logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
else:
logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
# --- End Index Extraction ---
# 3. Find Checkpoint and Load Model
checkpoint_path = None
try:
# Recursively glob for the checkpoint, which may be nested in subdirectories
checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
if not checkpoints:
logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
return None
if len(checkpoints) > 1:
logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
checkpoint_path = checkpoints[0]
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
model = LSTMForecastLightningModule.load_from_checkpoint(
checkpoint_path,
map_location=torch.device('cpu'), # Optional: load to CPU first if memory is tight
model_config=fold_config.model,
train_config=fold_config.training,
input_size=input_size,
target_scaler=target_scaler
)
model.eval()
logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons
except FileNotFoundError:
logger.error(f"Checkpoint file not found: {checkpoint_path}")
return None
except Exception as e:
logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
return None
except Exception as e:
logger.error(f"Generic error loading artifacts from {fold_dir}: {e}", exc_info=True)
return None
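# Illustrative unpacking (editor note): callers typically guard against None before
# unpacking the 7-tuple returned above, e.g.:
#   loaded = load_fold_model_and_objects(Path("output/cv_results/fold_01"))  # assumed path
#   if loaded is not None:
#       model, fold_cfg, test_loader, scaler, input_size, pred_index, horizons = loaded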
def make_ensemble_predictions(
models: List[LSTMForecastLightningModule],
test_loader: torch.utils.data.DataLoader,
device: Optional[torch.device] = None
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
"""
Make predictions using an ensemble of models efficiently.
Processes the test_loader once, getting predictions from all models per batch.
Args:
models: List of trained models (already in eval mode).
test_loader: DataLoader for the test set.
device: Device to run predictions on (e.g., torch.device("cuda:0")).
If None, attempts to use GPU if available, else CPU.
Returns:
Tuple of (ensemble_predictions, targets):
- ensemble_predictions: Dict containing ensemble predictions keyed by method
('mean', 'median', 'min', 'max'). Values are np.arrays.
Returns None if prediction fails.
- targets: Ground truth values as a single np.array. Returns None if prediction fails
or targets are unavailable in loader.
"""
if not models:
logger.warning("make_ensemble_predictions received an empty list of models.")
return None, None
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Running ensemble predictions on device: {device}")
# Move all models to the target device
for model in models:
model.to(device)
all_batch_preds: List[List[np.ndarray]] = [[] for _ in models] # Outer list: models, Inner list: batches
all_batch_targets: List[np.ndarray] = []
targets_available = True
with torch.no_grad():
for batch_idx, batch in enumerate(test_loader):
try:
# Determine if batch contains targets
if isinstance(batch, (list, tuple)) and len(batch) == 2:
x, y = batch
x = x.to(device)
# Keep targets on CPU until needed for concatenation
all_batch_targets.append(y.cpu().numpy())
else:
x = batch.to(device)
targets_available = False # No targets found in this batch
# Get predictions from all models for this batch
for i, model in enumerate(models):
try:
pred = model(x) # Shape: (batch, horizon)
all_batch_preds[i].append(pred.cpu().numpy())
except Exception as model_err:
logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
# Handle error: Fill with NaNs? Skip model? For now, fill with NaNs of expected shape
# Infer expected shape: (batch_size, horizon)
batch_size = x.shape[0]
n_horizons = models[0].output_size # Assume all models predict the same number of horizons
nan_preds = np.full((batch_size, n_horizons), np.nan)
all_batch_preds[i].append(nan_preds)
except Exception as batch_err:
logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
# If a batch fails catastrophically, we might not be able to proceed reliably
return None, None # Indicate failure
# Concatenate batch results for each model
model_preds_concat = []
for i in range(len(models)):
if not all_batch_preds[i]: # Check if any predictions were collected for this model
logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
continue # Skip this model if it failed on all batches
try:
model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
except ValueError as e:
logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
# Decide how to handle: skip model or fail? Let's skip for robustness.
continue
if not model_preds_concat:
logger.error("No valid predictions collected from any model in the ensemble.")
return None, None
# Concatenate targets if available
targets_concat = None
if targets_available and all_batch_targets:
try:
targets_concat = np.concatenate(all_batch_targets, axis=0)
except ValueError as e:
logger.error(f"Failed to concatenate targets: {e}")
return None, None # Fail if targets were expected but couldn't be combined
elif targets_available and not all_batch_targets:
logger.warning("Targets were expected based on first batch, but none were collected.")
# Proceed without targets, returning None for them
# Stack predictions from all models: Shape (num_models, num_samples, horizon)
try:
stacked_preds = np.stack(model_preds_concat, axis=0)
except ValueError as e:
logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
return None, targets_concat # Return targets if available, but no ensemble preds
# Calculate different ensemble predictions (handle NaNs potentially introduced by model failures)
# np.nanmean, np.nanmedian etc. ignore NaNs
ensemble_preds = {
'mean': np.nanmean(stacked_preds, axis=0),
'median': np.nanmedian(stacked_preds, axis=0),
'min': np.nanmin(stacked_preds, axis=0),
'max': np.nanmax(stacked_preds, axis=0)
}
logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
return ensemble_preds, targets_concat
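# Worked mini-example (editor note, illustrative only): with stacked predictions of
# shape (num_models, num_samples, len(horizons)), the NaN-aware reductions above
# simply ignore models that failed on a given element:
#
#   >>> import numpy as np
#   >>> stacked = np.array([[[1.0, 2.0]], [[3.0, np.nan]], [[5.0, 6.0]]])  # (3 models, 1 sample, 2 horizons)
#   >>> np.nanmean(stacked, axis=0)
#   array([[3., 4.]])
#   >>> np.nanmedian(stacked, axis=0)
#   array([[3., 4.]])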
def evaluate_ensemble_for_test_fold(
test_fold_num: int,
all_fold_dirs: List[Path],
output_base_dir: Path,
# full_data_index: Optional[pd.Index] = None # Removed, get from loaded objects
) -> Optional[Dict[str, Dict[str, float]]]:
"""
Evaluates ensemble predictions for a specific test fold.
Args:
test_fold_num: The 1-based number of the fold to use as the test set.
all_fold_dirs: List of paths to all fold directories.
output_base_dir: Base directory for saving evaluation results/plots.
Returns:
Dictionary containing metrics for each ensemble method for this test fold,
or None if evaluation fails.
"""
logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"
load_result = load_fold_model_and_objects(test_fold_dir)
if load_result is None:
logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
return None
# Unpack results including the prediction time index and horizons
_, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result
# Load models from all *other* folds
ensemble_models: List[LSTMForecastLightningModule] = []
model_forecast_horizons = None # Track horizons from loaded models
for i, fold_dir in enumerate(all_fold_dirs):
current_fold_num = i + 1
if current_fold_num == test_fold_num:
continue # Skip the test fold itself
model_load_result = load_fold_model_and_objects(fold_dir)
if model_load_result:
model, _, _, _, _, _, fold_horizons = model_load_result # Only need the model here
if model:
ensemble_models.append(model)
# Store horizons from the first successful model load
if model_forecast_horizons is None:
model_forecast_horizons = fold_horizons
# Optional: Check consistency of horizons across ensemble models
elif set(model_forecast_horizons) != set(fold_horizons):
logger.error(f"Inconsistent forecast horizons between ensemble models! Test fold {test_fold_num} expected {test_forecast_horizons}, "
f"Model {i+1} has {fold_horizons}. Ensemble may be invalid.")
# Decide how to handle: error out, or proceed with caution?
# return None # Option: Fail hard
else:
logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")
if len(ensemble_models) < 2:
logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
return {} # Return empty dict, not None, to indicate process ran but no ensemble formed
# Check consistency between test fold horizons and ensemble model horizons
if model_forecast_horizons is None: # Should not happen if len(ensemble_models) >= 1
logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
return None
if set(test_forecast_horizons) != set(model_forecast_horizons):
logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
f"horizons from ensemble models ({model_forecast_horizons}). Cannot evaluate.")
return None
# Make ensemble predictions using the loaded models and the test fold's data loader
# Use the test fold's config to determine device implicitly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)
if ensemble_preds_dict is None or targets_np is None:
logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
return None # Indicate failure
# Evaluate each ensemble method's predictions against the test fold's targets
fold_ensemble_results: Dict[str, Dict[str, float]] = {}
for method, preds_np in ensemble_preds_dict.items():
logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")
# Define a unique output directory for this method's plots
method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"
# Use the prediction_target_time_index loaded earlier
prediction_time_index_for_plot = None
if prediction_target_time_index is not None:
if len(prediction_target_time_index) == targets_np.shape[0]:
prediction_time_index_for_plot = prediction_target_time_index
else:
logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")
# Call the standard evaluation function
metrics = evaluate_fold_predictions(
y_true_scaled=targets_np,
y_pred_scaled=preds_np,
target_scaler=target_scaler,
eval_config=test_fold_config.evaluation,
fold_num=test_fold_num - 1,
output_dir=str(method_plot_dir.parent.parent),
plot_subdir=f"method_{method}",
prediction_time_index=prediction_time_index_for_plot, # Pass the index
forecast_horizons=test_forecast_horizons,
plot_title_prefix=f"Ensemble ({method})"
)
fold_ensemble_results[method] = metrics
logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
return fold_ensemble_results
def run_ensemble_evaluation(
config: MainConfig, # Pass main config for context if needed, though fold configs are loaded
output_base_dir: Path,
# full_data_index: Optional[pd.Index] = None # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
"""
Run ensemble evaluation across all folds, treating each as the test set once.
Args:
config: The main configuration object (potentially unused if fold configs sufficient).
output_base_dir: Base directory where fold outputs are stored.
Returns:
Dictionary containing ensemble metrics for each test fold:
{ test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
"""
logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}
# Discover fold directories
fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
if not fold_dirs:
logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
return {}
if len(fold_dirs) < 2:
logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. Skipping.")
return {}
logger.info(f"Found {len(fold_dirs)} fold directories.")
# Iterate through each fold, designating it as the test fold
for i, test_fold_dir in enumerate(fold_dirs):
test_fold_num = i + 1 # 1-based fold number
try:
results_for_test_fold = evaluate_ensemble_for_test_fold(
test_fold_num=test_fold_num,
all_fold_dirs=fold_dirs,
output_base_dir=output_base_dir,
# full_data_index=full_data_index # Removed
)
if results_for_test_fold is not None:
# Only add results if the evaluation didn't fail completely
all_ensemble_results[test_fold_num] = results_for_test_fold
except Exception as e:
# Catch unexpected errors during a specific test fold evaluation
logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True)
continue # Continue to the next fold
# Saving is handled by the main script (`forecasting_model_run.py`) which calls this
if not all_ensemble_results:
logger.warning("Ensemble evaluation finished, but no results were generated.")
else:
logger.info("===== Finished Cross-Validated Ensemble Evaluation =====")
return all_ensemble_results
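# --- Illustrative usage sketch (editor addition, not part of the original module) ---
# Minimal example of running the cross-validated ensemble evaluation over an
# existing CV output directory (fold_01, fold_02, ... as produced by the main
# runner). The config and output paths below are assumptions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    with open("config.yaml", "r") as f:  # assumed config location
        cfg = MainConfig(**yaml.safe_load(f))
    results = run_ensemble_evaluation(cfg, output_base_dir=Path("output/cv_results"))
    for fold_num, per_method in results.items():
        for method, metrics in per_method.items():
            logger.info(f"Test fold {fold_num} | {method}: {metrics}")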

View File

@ -9,7 +9,7 @@ from typing import Optional, Dict, Any, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Assuming config_model is in sibling directory utils/
from forecasting_model.utils.config_model import ModelConfig, TrainingConfig
from forecasting_model.utils.forecast_config_model import ModelConfig, TrainingConfig
logger = logging.getLogger(__name__)
@ -30,41 +30,42 @@ class LSTMForecastLightningModule(pl.LightningModule):
super().__init__()
# --- Validate & Store Configs ---
# Validate the input_size passed during instantiation
if input_size <= 0:
raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")
self._input_size = input_size # Use a temporary attribute
# Store the validated input_size directly for use in layer definitions
self._input_size = input_size # Use a temporary attribute before hparams are saved
# Ensure forecast_horizon is a valid list in the config
if not hasattr(model_config, 'forecast_horizon') or \
not isinstance(model_config.forecast_horizon, list) or \
not model_config.forecast_horizon or \
any(h <= 0 for h in model_config.forecast_horizon):
raise ValueError("ModelConfig requires `forecast_horizon` to be a non-empty list of positive integers.")
# Ensure forecast_horizon is set in the config for the output layer
if not hasattr(model_config, 'forecast_horizon') or model_config.forecast_horizon is None or model_config.forecast_horizon <= 0:
raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
self.output_size = model_config.forecast_horizon
# Output size is the number of horizons we predict
self.output_size = len(model_config.forecast_horizon)
# Store the actual horizon list for reference if needed, ensure sorted
self.forecast_horizons = sorted(model_config.forecast_horizon)
# Store configurations - input_size argument will be saved via save_hyperparameters
self.model_config = model_config
self.train_config = train_config
self.target_scaler = target_scaler # Store scaler for this fold
# Use save_hyperparameters() to automatically log configs and allow loading
# Pass input_size explicitly to be saved in hparams
# Exclude scaler as it's stateful and fold-specific
# Use save_hyperparameters() - forecast_horizon is part of model_config which is saved
self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])
# Note: Pydantic models might not be perfectly saved/loaded by PL's hparams, check if needed.
# If issues arise loading, might need to flatten relevant hparams manually.
# --- Define Model Layers ---
# Access input_size via hparams now
self.lstm = nn.LSTM(
input_size=self.hparams.input_size,
hidden_size=self.hparams.model_config.hidden_size,
num_layers=self.hparams.model_config.num_layers,
batch_first=True, # Input shape: (batch, seq_len, features)
batch_first=True,
dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
)
self.dropout = nn.Dropout(self.hparams.model_config.dropout)
# Output layer maps LSTM hidden state to the forecast horizon
# We typically take the output of the last time step
# Output layer maps LSTM hidden state to the number of forecast horizons
self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)
# Optional residual connection handling
@ -96,7 +97,7 @@ class LSTMForecastLightningModule(pl.LightningModule):
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()
self.val_MeanAbsoluteError_Original_Scale = torchmetrics.MeanAbsoluteError()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -107,7 +108,8 @@ class LSTMForecastLightningModule(pl.LightningModule):
x: Input tensor of shape (batch_size, sequence_length, input_size)
Returns:
Predictions tensor of shape (batch_size, forecast_horizon)
Predictions tensor of shape (batch_size, len(forecast_horizons))
where each element corresponds to a predicted horizon in sorted order.
"""
# LSTM forward pass
lstm_out, (hidden, cell) = self.lstm(x) # Shape: (batch, seq_len, hidden_size)
@ -126,59 +128,50 @@ class LSTMForecastLightningModule(pl.LightningModule):
last_time_step_out = last_time_step_out + residual
# Final fully connected layer
predictions = self.fc(last_time_step_out) # Shape: (batch_size, output_size/horizon)
predictions = self.fc(last_time_step_out) # Shape: (batch_size, output_size/len(horizons))
return predictions # Shape: (batch_size, forecast_horizon)
return predictions # Shape: (batch_size, len(forecast_horizons))
def _calculate_loss(self, outputs, targets):
# Ensure shapes match before loss calculation
if outputs.shape != targets.shape:
# Squeeze potential extra dim: (batch, horizon, 1) -> (batch, horizon)
if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
outputs = outputs.squeeze(-1)
# Shapes should now be (batch_size, len(horizons)) for both
if outputs.shape != targets.shape:
# Minimal check, dataset __getitem__ should ensure this
raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
return self.criterion(outputs, targets)
def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
"""Helper to inverse transform data using the stored target scaler."""
"""Helper to inverse transform data (preds or targets) using the stored target scaler."""
if self.target_scaler is None:
# logger.warning("Cannot inverse transform: target_scaler not available.")
return None # Cannot inverse transform
return None
data_cpu = data.detach().cpu().numpy().astype(np.float64)
original_shape = data_cpu.shape # e.g., (batch_size, len(horizons))
num_elements = data_cpu.size
# Scaler expects 2D input (N, 1)
# Ensure data is on CPU and is float64 for sklearn scaler typically
data_cpu = data.detach().cpu().numpy().astype(np.float64)
original_shape = data_cpu.shape
if data_cpu.ndim == 1:
data_flat = data_cpu.reshape(-1, 1)
elif data_cpu.ndim == 2: # (batch, horizon)
data_flat = data_cpu.reshape(-1, 1)
else:
logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
data_flat = data_cpu.reshape(-1, 1)
data_flat = data_cpu.reshape(num_elements, 1)
try:
inversed_np = self.target_scaler.inverse_transform(data_flat)
# Return as tensor on the original device
# Return as tensor on the original device, potentially reshaped
inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
# Reshape back? Or keep flat? Keep flat for direct metric use often.
return inversed_tensor.flatten()
# return inversed_tensor.reshape(original_shape) # If original shape needed
# Reshape back to original multi-horizon shape
return inversed_tensor.reshape(original_shape)
# return inversed_tensor.flatten() # Keep flat if needed for specific metric inputs
except Exception as e:
logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
return None # Return None if inverse transform fails
return None
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
x, y = batch # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
outputs = self(x) # Scaled outputs: (batch, horizon)
x, y = batch # Shapes: x=(batch, seq_len, features), y=(batch, len(horizons))
outputs = self(x) # Scaled outputs: (batch, len(horizons))
loss = self._calculate_loss(outputs, y)
# Log scaled metrics
metrics = self.train_metrics(outputs, y) # Update internal state
self.train_metrics.update(outputs, y)
self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True) # Log all metrics in collection
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)
return loss
@ -188,20 +181,22 @@ class LSTMForecastLightningModule(pl.LightningModule):
loss = self._calculate_loss(outputs, y)
# Log scaled metrics
metrics = self.val_metrics(outputs, y) # Update internal state
self.val_metrics.update(outputs, y)
self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)
# Log MAE on ORIGINAL scale if scaler is available (often the primary metric for checkpointing/Optuna)
# Log MAE on ORIGINAL scale (primary metric for checkpoints)
if self.target_scaler is not None:
# Inverse transform keeps the (batch, len(horizons)) shape
outputs_inv = self._inverse_transform(outputs)
y_inv = self._inverse_transform(y)
if outputs_inv is not None and y_inv is not None:
# Ensure shapes are compatible (flattened by _inverse_transform)
# Ensure shapes match
if outputs_inv.shape == y_inv.shape:
self.val_mae_original_scale.update(outputs_inv, y_inv)
self.log('val_mae_orig_scale', self.val_mae_original_scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
# It will compute the average MAE across all elements if multi-dim
self.val_MeanAbsoluteError_Original_Scale.update(outputs_inv, y_inv)
self.log('val_MeanAbsoluteError_Original_Scale', self.val_MeanAbsoluteError_Original_Scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
else:
logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
else:

View File

@ -5,7 +5,7 @@ This package contains configuration models, helper functions, and other utilitie
"""
# Expose configuration models
from .config_model import (
from .forecast_config_model import (
MainConfig,
DataConfig,
FeatureConfig,

View File

@ -44,7 +44,7 @@ class DataConfig(BaseModel):
class FeatureConfig(BaseModel):
"""Configuration for feature engineering and preprocessing."""
sequence_length: int = Field(..., gt=0)
forecast_horizon: int = Field(..., gt=0)
forecast_horizon: List[int] = Field(..., min_length=1, description="List of specific forecast horizons to predict (e.g., [1, 6, 12]).")
lags: List[int] = []
rolling_window_sizes: List[int] = []
use_time_features: bool = True
@ -55,11 +55,11 @@ class FeatureConfig(BaseModel):
clipping: ClippingConfig = ClippingConfig() # Default instance
scaling_method: Optional[Literal['standard', 'minmax']] = 'standard' # Added literal validation
@field_validator('lags', 'rolling_window_sizes')
@field_validator('lags', 'rolling_window_sizes', 'forecast_horizon')
@classmethod
def check_positive_list_values(cls, v: List[int]) -> List[int]:
if any(val <= 0 for val in v):
raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
raise ValueError('Lists lags, rolling_window_sizes, and forecast_horizon must contain only positive values')
return v
class ModelConfig(BaseModel):
@ -69,8 +69,8 @@ class ModelConfig(BaseModel):
num_layers: int = Field(..., gt=0)
dropout: float = Field(..., ge=0.0, le=1.0)
use_residual_skips: bool = False
# Add forecast_horizon here to ensure LightningModule gets it directly
forecast_horizon: Optional[int] = Field(None, gt=0) # Will be set from FeatureConfig
# forecast_horizon: Optional[int] = Field(None, gt=0) # OLD
forecast_horizon: Optional[List[int]] = Field(None, min_length=1) # Will be set from FeatureConfig
class TrainingConfig(BaseModel):
"""Configuration for the training process (PyTorch Lightning)."""
@ -103,10 +103,11 @@ class EvaluationConfig(BaseModel):
class OptunaConfig(BaseModel):
"""Optional configuration for Optuna hyperparameter optimization."""
enabled: bool = False
study_name: str = "default_study" # Added study_name
n_trials: int = Field(20, gt=0)
storage: Optional[str] = None # e.g., "sqlite:///output/hpo_results/study.db"
direction: Literal['minimize', 'maximize'] = 'minimize'
metric_to_optimize: str = 'val_mae_orig_scale'
metric_to_optimize: str = 'val_MeanAbsoluteError_Original_Scale' # Updated default metric
pruning: bool = True
# --- Top-Level Configuration Model ---
@ -114,15 +115,23 @@ class OptunaConfig(BaseModel):
class MainConfig(BaseModel):
"""Main configuration model nesting all sections."""
project_name: str = "TimeSeriesForecasting"
random_seed: Optional[int] = 42 # Added top-level seed
random_seed: Optional[int] = 42
log_level: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
output_dir: str = Field("output/cv_results", description="Base directory for saving all outputs (results, logs, models, plots).")
# --- Execution Control ---
run_cross_validation: bool = Field(True, description="Run the main cross-validation training loop?")
run_classic_training: bool = Field(True, description="Run a single classic train/val/test split training?")
run_ensemble_evaluation: bool = Field(True, description="Run ensemble evaluation using CV fold models?")
# --- End Execution Control ---
data: DataConfig
features: FeatureConfig
model: ModelConfig # ModelConfig no longer contains input_size
model: ModelConfig
training: TrainingConfig
cross_validation: CrossValidationConfig
evaluation: EvaluationConfig
optuna: Optional[OptunaConfig] = OptunaConfig() # Added optional Optuna config
optuna: Optional[OptunaConfig] = OptunaConfig()
@model_validator(mode='after')
def check_forecast_horizon_consistency(self) -> 'MainConfig':
@ -131,20 +140,33 @@ class MainConfig(BaseModel):
if self.model.forecast_horizon is None:
# If model config doesn't have it, set it from features config
self.model.forecast_horizon = self.features.forecast_horizon
elif self.model.forecast_horizon != self.features.forecast_horizon:
elif set(self.model.forecast_horizon) != set(self.features.forecast_horizon): # Compare sets for content equality
# If both are set but differ, raise error
raise ValueError(
f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
)
# After potential setting, ensure model.forecast_horizon is actually set
if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
# After potential setting, ensure model.forecast_horizon is actually set and valid
if self.model and (
self.model.forecast_horizon is None or
not isinstance(self.model.forecast_horizon, list) or # Check type
len(self.model.forecast_horizon) == 0 or # Check not empty
any(h <= 0 for h in self.model.forecast_horizon) # Check positive values
):
raise ValueError("ModelConfig requires a non-empty list of positive forecast_horizon values (must be set in features config if not set explicitly in model config).")
# Input size check is removed as it's not part of static config anymore
return self
@model_validator(mode='after')
def check_execution_flags(self) -> 'MainConfig':
if not self.run_cross_validation and not self.run_classic_training:
raise ValueError("At least one of 'run_cross_validation' or 'run_classic_training' must be True.")
if self.run_ensemble_evaluation and not self.run_cross_validation:
raise ValueError("'run_ensemble_evaluation' requires 'run_cross_validation' to be True (needs CV fold models).")
return self
class Config:
# Example configuration for Pydantic itself
validate_assignment = True # Re-validate on assignment

View File

@ -0,0 +1,173 @@
import argparse
import json
import logging
import random
from pathlib import Path
from typing import Optional, List, Dict
import numpy as np
import pandas as pd
import torch
import yaml
from forecasting_model import MainConfig
# Get the root logger
logger = logging.getLogger(__name__)
def parse_arguments():
"""Parses command-line arguments."""
parser = argparse.ArgumentParser(
description="Run the Time Series Forecasting training pipeline using a configuration file.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the YAML configuration file."
)
# Removed seed, debug, and output-dir arguments
args = parser.parse_args()
return args
def load_config(config_path: Path) -> MainConfig:
"""
Load and validate configuration from YAML file using Pydantic.
Args:
config_path: Path to the YAML configuration file.
Returns:
Validated MainConfig object.
Raises:
FileNotFoundError: If the config file doesn't exist.
yaml.YAMLError: If the file is not valid YAML.
pydantic.ValidationError: If the config doesn't match the schema.
"""
if not config_path.is_file():
logger.error(f"Configuration file not found at: {config_path}")
raise FileNotFoundError(f"Config file not found: {config_path}")
logger.info(f"Loading configuration from: {config_path}")
try:
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
# Validate configuration using Pydantic model
config = MainConfig(**config_dict)
logger.info("Configuration loaded and validated successfully.")
return config
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
raise
except Exception as e: # Catches Pydantic validation errors too
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
raise
def set_seeds(seed: Optional[int] = 42) -> None:
"""
Set random seeds for reproducibility across libraries.
Args:
seed: The seed value to use. If None, uses default 42.
"""
actual_seed = seed if seed is not None else 42
if seed is None:
logger.warning(f"No random_seed specified in config, using default seed: {actual_seed}")
else:
logger.info(f"Setting random seed from config: {actual_seed}")
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
# Ensure reproducibility for CUDA operations where possible
if torch.cuda.is_available():
torch.cuda.manual_seed(actual_seed)
torch.cuda.manual_seed_all(actual_seed) # For multi-GPU
# These settings can slow down training but improve reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# PyTorch Lightning seeding (optional, as we seed torch directly)
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""
Calculate mean and standard deviation of metrics across folds.
Handles potential NaN values by ignoring them.
Args:
all_fold_metrics: A list where each element is a dictionary of
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
Returns:
A dictionary where keys are metric names and values are dicts
containing 'mean' and 'std' for that metric across folds.
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
"""
if not all_fold_metrics:
logger.warning("Received empty list for metric aggregation.")
return {}
aggregated: Dict[str, Dict[str, float]] = {}
# Get metric names from the first valid fold's results
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
if not first_valid_metrics:
logger.warning("No valid fold metrics found for aggregation.")
return {}
metric_names = list(first_valid_metrics.keys())
for metric in metric_names:
# Collect values for this metric across all folds, ignoring NaNs
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
valid_values = [v for v in values if v is not None and not np.isnan(v)]
if not valid_values:
logger.warning(f"No valid values found for metric '{metric}' across folds.")
mean_val = np.nan
std_val = np.nan
else:
mean_val = float(np.mean(valid_values))
std_val = float(np.std(valid_values))
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
aggregated[metric] = {'mean': mean_val, 'std': std_val}
return aggregated
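# Worked mini-example (editor note, illustrative only): NaNs are dropped per metric
# before averaging, so a fold with a missing RMSE does not poison the aggregate:
#
#   >>> aggregate_cv_metrics([{'MAE': 1.0, 'RMSE': 2.0}, {'MAE': 3.0, 'RMSE': float('nan')}])
#   {'MAE': {'mean': 2.0, 'std': 1.0}, 'RMSE': {'mean': 2.0, 'std': 0.0}}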
def save_results(results: Dict, filename: Path):
"""Save dictionary results to a JSON file."""
try:
filename.parent.mkdir(parents=True, exist_ok=True)
# Convert numpy types to native Python types for JSON serialization
results_serializable = json.loads(json.dumps(results, cls=NumpyEncoder))
with open(filename, 'w') as f:
json.dump(results_serializable, f, indent=4)
logger.info(f"Saved results to {filename}")
except TypeError as e:
logger.error(f"Serialization error saving results to {filename}. Check for non-serializable types (e.g., numpy types): {e}", exc_info=True)
except Exception as e:
logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.bool_, bool)):
return bool(obj)
elif pd.isna(obj): # Handle pandas NaT or numpy NaN gracefully
return None
return super(NumpyEncoder, self).default(obj)
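# --- Illustrative usage sketch (editor addition, not part of the original module) ---
# save_results serializes numpy scalars and arrays via NumpyEncoder, so fold
# metrics containing np.float64 or np.ndarray values round-trip to JSON cleanly.
# The output path below is an assumption.
if __name__ == "__main__":
    demo = {
        "overall_metrics": {"MAE": np.float64(1.23), "RMSE": np.float64(2.34)},
        "per_horizon_mae": np.array([1.1, 1.4, 1.9]),
    }
    save_results(demo, Path("output/_demo_metrics.json"))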

View File

@ -1,30 +1,34 @@
import argparse
import logging
import sys
import os
import random
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Import necessary components from your project structure
# Assuming forecasting_model is a package installable or in PYTHONPATH
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from typing import Dict, List, Any, Optional
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
# Import the new classic training function
from forecasting_model.train.classic import run_classic_training
from typing import Dict, List, Optional, Tuple, Union
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
@ -33,203 +37,75 @@ pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
# Configure logging early. Level might be adjusted by config.
# Configure logging early. Level might be adjusted by config later.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Single Fold Processing Function ---
def run_single_fold(
    fold_num: int,
    train_idx: np.ndarray,
    val_idx: np.ndarray,
    test_idx: np.ndarray,
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path # Receives Path object from run_training_pipeline
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
    """
    Runs the pipeline for a single cross-validation fold.

    Args:
        fold_num: The zero-based index of the current fold.
        train_idx: Indices for the training set.
        val_idx: Indices for the validation set.
        test_idx: Indices for the test set.
        config: The main configuration object.
        full_df: The complete raw DataFrame.
        output_base_dir: The base directory Path for saving results.

    Returns:
        A tuple containing:
        - fold_metrics: Dictionary of test metrics for the fold (e.g., {'MAE': ..., 'RMSE': ...}).
        - best_val_score: The best validation score achieved during training (or None).
        - saved_model_path: Path to the best saved model checkpoint (or None).
        - saved_target_scaler_path: Path to the saved target scaler (or None).
        - saved_input_size_path: Path to the saved input size file (or None).
        - saved_config_path: Path to the saved config file for this fold (or None).
    """
fold_start_time = time.perf_counter()
    fold_id = fold_num + 1 # User-facing fold number (1-based)
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")
fold_metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Default in case of failure
best_val_score: Optional[float] = None
best_model_path_str: Optional[str] = None # Use a different name for the string from callback
# Variables to hold prediction results for plotting later
all_preds_scaled: Optional[np.ndarray] = None
all_targets_scaled: Optional[np.ndarray] = None
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Need to keep scaler reference
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
# Variables to store paths of saved artifacts
saved_model_path: Optional[Path] = None
saved_target_scaler_path: Optional[Path] = None
saved_input_size_path: Optional[Path] = None
saved_config_path: Optional[Path] = None
try:
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
        # Keep scaler and input_size references returned by prepare_fold_data_and_loaders
        train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders(
            full_df=full_df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
@ -238,41 +114,55 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
train_config=config.training,
eval_config=config.evaluation
)
target_scaler = target_scaler_fold # Store the scaler in the outer scope
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# Save necessary items for potential later use (e.g., ensemble)
# Capture the paths when saving
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
torch.save(target_scaler, saved_target_scaler_path)
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
# Save input size and capture path
saved_input_size_path = fold_output_dir / "input_size.pt"
torch.save(input_size, saved_input_size_path)
# Save config for this fold (needed for reloading model) and capture path
config_dump = config.model_dump()
saved_config_path = fold_output_dir / "config.yaml" # Capture the path before saving
with open(saved_config_path, 'w') as f:
yaml.dump(config_dump, f, default_flow_style=False)
# --- Model Initialization ---
model = LSTMForecastLightningModule(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler_fold # Pass scaler during init
)
logger.info("LSTMForecastLightningModule initialized.")
# --- PyTorch Lightning Callbacks ---
        # Ensure monitor_metric matches the exact name logged in model.py
        monitor_metric = "val_MeanAbsoluteError_Original_Scale"
monitor_mode = "min"
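        # This key must match exactly what the LightningModule logs via self.log(...);
        # otherwise EarlyStopping/ModelCheckpoint cannot find the metric and will warn or error out.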
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
                min_delta=0.0001, # Minimum change to qualify as improvement
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
# Checkpoint callback to save the best model based on validation metric
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
filename=f"best_model_fold_{fold_id}",
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
@ -280,7 +170,6 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
# Learning rate monitor callback
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
@ -288,15 +177,16 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
callbacks.append(early_stop_callback)
# --- PyTorch Lightning Logger ---
# Log to a subdir specific to the fold, relative to output_base_dir
log_dir = output_base_dir / f"fold_{fold_id:02d}" / "training_logs"
pl_logger = CSVLogger(save_dir=str(log_dir.parent), name=log_dir.name, version='') # Use name for subdir, version='' to avoid 'version_0'
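        # Resulting layout (assuming Lightning's CSVLogger conventions):
        # <output_base_dir>/fold_XX/training_logs/metrics.csv -- version='' suppresses the usual version_0 subdir.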
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
# Determine accelerator and devices based on PyTorch check
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs, e.g. [0]
        precision = getattr(config.training, 'precision', 32) # Default to 32-bit
trainer = pl.Trainer(
accelerator=accelerator,
@ -304,11 +194,10 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
            enable_progress_bar=True,
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
# deterministic=True, # For stricter reproducibility (can slow down)
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
@ -317,112 +206,383 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")
# Store best validation score and path
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
# Capture the best model path reported by the checkpoint callback
best_model_path_str = trainer.checkpoint_callback.best_model_path # Capture the string path
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
logger.info(f"Best model checkpoint path: {best_model_path}")
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {best_val_score:.4f}")
# Check if best_model_path was actually set by the callback
if best_model_path_str:
saved_model_path = Path(best_model_path_str) # Convert string to Path object and store
logger.info(f"Best model checkpoint path: {best_model_path_str}")
else:
logger.warning(f"ModelCheckpoint callback did not report a best_model_path for Fold {fold_id}.")
else:
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
            best_model_path_str = None # Ensure string path is None if no best score
# --- Prediction on Test Set ---
logger.info(f"Starting prediction for Fold {fold_id} using {'best checkpoint' if saved_model_path else 'last model'}...")
# Use the best checkpoint path if available, otherwise use the in-memory model instance
ckpt_path_for_predict = str(saved_model_path) if saved_model_path else None # Use the saved Path object, convert to string for ckpt_path
prediction_results_list = trainer.predict(
            model=model, # Use the in-memory model instance
            dataloaders=test_loader,
            ckpt_path=ckpt_path_for_predict # Specify checkpoint path if needed, though using model=model is typical
)
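        # When ckpt_path is provided, Lightning restores those checkpoint weights into the passed
        # module before running predict_step; with ckpt_path=None the current in-memory weights are used.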
        # --- Process Prediction Results & Get Time Index ---
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
all_preds_scaled = None # Ensure these are None on failure
all_targets_scaled = None
else:
try:
# Concatenate predictions and targets from predict_step results
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
# Check if targets were included (they should be if using test_loader)
n_predictions = len(all_preds_scaled)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
if len(all_targets_scaled) != n_predictions:
logger.error(f"Fold {fold_id}: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
raise ValueError("Prediction and target count mismatch during evaluation.")
else:
# This case shouldn't happen if using test_loader, but good safeguard
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
raise ValueError("Targets missing from prediction results.")
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate or plot original scale targets.")
all_targets_scaled = None
# --- Final Evaluation & Plotting ---
logger.info(f"Processing prediction results for Fold {fold_id}...")
logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
prediction_target_time_index_h1 = None
if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
try:
test_block_index = full_df.index[test_idx]
seq_len = config.features.sequence_length
first_horizon = config.features.forecast_horizon[0]
target_indices_h1 = test_idx + seq_len + first_horizon - 1
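                        # Offset logic: a window whose input starts at raw index i spans rows
                        # i .. i+seq_len-1, so its first-horizon target sits at i + seq_len + h1 - 1.
                        # Illustrative example: with seq_len=72 and first_horizon=1, the window
                        # starting at i=0 is scored against the observation at raw index 72.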
valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]
if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
if len(prediction_target_time_index_h1) != n_predictions:
logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
prediction_target_time_index_h1 = None
else:
logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
prediction_target_time_index_h1 = None
# Save the calculated index if it's valid and evaluation plots are enabled
if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
try:
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
except Exception as save_e:
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
elif prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
except Exception as e:
logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
prediction_target_time_index_h1 = None
else:
logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
if prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
# --- End Index Calculation and Saving ---
# --- Evaluation ---
if all_targets_scaled is not None: # Only evaluate if targets are available
fold_metrics = evaluate_fold_predictions(
                        y_true_scaled=all_targets_scaled, # Pass the (N, H) array
                        y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
                        target_scaler=target_scaler,
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
                        output_dir=str(fold_output_dir),
plot_subdir="plots",
# Pass the calculated index for the targets being plotted (h1 reference)
prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
plot_title_prefix=f"CV Fold {fold_id}"
)
# Save fold metrics
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
else:
logger.error(f"Skipping evaluation for Fold {fold_id} due to missing targets.")
# --- Multi-Horizon Plotting ---
if config.evaluation.save_plots and all_preds_scaled is not None and all_targets_scaled is not None and prediction_target_time_index_h1 is not None and target_scaler is not None:
logger.info(f"Generating multi-horizon plot for Fold {fold_id}...")
try:
multi_horizon_plot_path = fold_output_dir / "plots" / "multi_horizon_forecast.png"
                        # save_plot is imported from forecasting_model.io.plotting at the top of this file
fig = create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons=all_targets_scaled,
y_pred_scaled_all_horizons=all_preds_scaled,
target_scaler=target_scaler,
prediction_time_index_h1=prediction_target_time_index_h1,
forecast_horizons=config.features.forecast_horizon,
title=f"Fold {fold_id} Multi-Horizon Forecast",
max_points=1000 # Limit points for clarity
)
# Check if save_plot is available or use fig.savefig()
try:
save_plot(fig, multi_horizon_plot_path)
except NameError:
# Fallback if save_plot is not defined/imported
fig.savefig(multi_horizon_plot_path)
plt.close(fig) # Close the figure after saving
logger.warning("Using fig.savefig as save_plot function was not found.")
except Exception as plot_e:
logger.error(f"Fold {fold_id}: Failed to generate multi-horizon plot: {plot_e}", exc_info=True)
elif config.evaluation.save_plots:
logger.warning(f"Fold {fold_id}: Skipping multi-horizon plot due to missing data (preds, targets, time index, or scaler).")
except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
except ValueError as e: # Catch specific error from above
logger.error(f"ValueError processing prediction results for Fold {fold_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
# --- Plot Loss Curve for Fold ---
try:
            actual_log_dir = Path(pl_logger.log_dir)  # CSVLogger.log_dir already resolves to .../fold_XX/training_logs
metrics_file_path = actual_log_dir / "metrics.csv"
# --- (Optional) Log final test metrics using trainer.test() ---
# If you want the metrics logged by test_step aggregated, call test now.
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
# try:
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
# except Exception as e:
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
if metrics_file_path.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file_path,
output_path=fold_output_dir / "plots" / "loss_curve.png", # Save in plots subdir
title=f"Fold {fold_id} Training Progression",
train_loss_col='train_loss',
val_loss_col='val_loss' # This function handles fallback
)
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
else:
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
# --- End Loss Curve Plotting ---
except Exception as e:
# Catch errors during the fold processing (data prep, training, prediction, eval)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
        # fold_metrics keeps its NaN defaults; saved artifact paths remain None unless they were set before the error.
# --- Cleanup per fold ---
finally:
# Clean up GPU memory explicitly
del model, trainer # Ensure objects are deleted before clearing cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared CUDA cache.")
# Delete loaders explicitly if they might hold references
del train_loader, val_loader, test_loader
fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
# Return the calculated fold metrics, best validation score, and saved artifact paths
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path
# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
"""Runs the full training and evaluation pipeline based on config flags."""
start_time = time.perf_counter()
logger.info(f"Starting training pipeline. Results will be saved to: {output_base_dir}")
output_base_dir.mkdir(parents=True, exist_ok=True)
# --- Data Loading ---
try:
df = load_raw_data(config.data)
except Exception as e:
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
sys.exit(1)
# --- Initialize results ---
all_fold_test_metrics: List[Dict[str, float]] = []
all_fold_best_val_scores: Dict[int, Optional[float]] = {}
aggregated_metrics: Dict = {}
final_results: Dict = {} # Initialize empty results dict
# --- Cross-Validation Loop ---
if config.run_cross_validation:
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
try:
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
# Unpack the two new return values from run_single_fold
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
config=config,
full_df=df,
output_base_dir=output_base_dir
)
all_fold_test_metrics.append(fold_metrics)
all_fold_best_val_scores[fold_num + 1] = best_val_score
# --- Aggregation and Reporting for CV ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
# Save aggregated results
final_results['aggregated_test_metrics'] = aggregated_metrics
final_results['per_fold_test_metrics'] = all_fold_test_metrics
final_results['per_fold_best_val_scores'] = all_fold_best_val_scores
# Save intermediate results after CV
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
        # Log aggregated CV results
        logger.info("--- Aggregated Cross-Validation Test Results ---")
        if aggregated_metrics:
            for metric, stats in aggregated_metrics.items():
                logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
        else:
            logger.warning("No metrics available for aggregation.")
    else:
        logger.info("Skipping Cross-Validation loop as per config.")
# --- Ensemble Evaluation ---
if config.run_ensemble_evaluation:
# The validator in MainConfig already ensures run_cross_validation is also true here
logger.info("Starting ensemble evaluation...")
try:
ensemble_results = run_ensemble_evaluation(
config=config, # Pass config for context if needed by sub-functions
output_base_dir=output_base_dir
)
if ensemble_results:
logger.info("Ensemble evaluation completed successfully")
final_results['ensemble_results'] = ensemble_results
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
else:
logger.warning("No ensemble results were generated (potentially < 2 folds available).")
except Exception as e:
logger.error(f"Error during ensemble evaluation: {e}", exc_info=True)
else:
logger.info("Skipping Ensemble evaluation as per config.")
# --- Classic Training Run ---
if config.run_classic_training:
logger.info("Starting classic training run...")
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
try:
# Call the original classic training function directly
classic_metrics = run_classic_training(
config=config,
full_df=df,
output_base_dir=output_base_dir # It creates classic_run subdir internally
)
if classic_metrics:
logger.info(f"Classic training run completed. Test Metrics: {classic_metrics}")
final_results['classic_training_results'] = classic_metrics
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
# --- Plot Loss Curve for Classic Run ---
try:
classic_log_dir = classic_output_dir / "training_logs"
metrics_file = classic_log_dir / "metrics.csv"
version_dirs = list(classic_log_dir.glob("version_*"))
if version_dirs:
metrics_file = version_dirs[0] / "metrics.csv"
if metrics_file.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file,
output_path=classic_output_dir / "loss_curve.png",
title="Classic Run Training Progression",
train_loss_col='train_loss', # Changed from 'train_loss_epoch'
val_loss_col='val_loss' # Check your logged metric names
)
else:
logger.warning(f"Classic Run: Could not find metrics.csv at {metrics_file} for loss curve plot.")
except Exception as plot_e:
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
# --- End Classic Loss Plotting ---
else:
logger.warning("Classic training run did not produce metrics.")
except Exception as e:
logger.error(f"Error during classic training run: {e}", exc_info=True)
else:
logger.info("Skipping Classic training run as per config.")
# --- Final Logging Summary ---
logger.info("--- Final Summary ---")
# Log aggregated CV results if they exist
if 'aggregated_test_metrics' in final_results and final_results['aggregated_test_metrics']:
logger.info("--- Aggregated Cross-Validation Test Results ---")
for metric, stats in final_results['aggregated_test_metrics'].items():
logger.info(f"{metric}: {stats.get('mean', np.nan):.4f} ± {stats.get('std', np.nan):.4f}")
elif config.run_cross_validation:
logger.warning("Cross-validation was run, but no metrics were aggregated.")
# Log aggregated ensemble results if they exist
if 'ensemble_results' in final_results and final_results['ensemble_results']:
logger.info("--- Aggregated Ensemble Test Results (Mean over Test Folds) ---")
        agg_ensemble: Dict[str, Dict[str, List[float]]] = {}
        for fold_res in final_results['ensemble_results'].values():
            if not isinstance(fold_res, dict):
                logger.warning("Skipping non-dict fold result in ensemble aggregation.")
                continue
            for method, metrics in fold_res.items():
                if not isinstance(metrics, dict):
                    logger.warning(f"Skipping non-dict metrics for ensemble method '{method}'.")
                    continue
                method_agg = agg_ensemble.setdefault(method, {})
                for m_name, m_val in metrics.items():
                    method_agg.setdefault(m_name, []).append(m_val)
        for method, metrics_data in agg_ensemble.items():
            logger.info(f"  Ensemble Method: {method}")
            for m_name, values in metrics_data.items():
                valid_vals = [v for v in values if v is not None and not np.isnan(v)]
                if valid_vals:
                    logger.info(f"    {m_name}: {np.mean(valid_vals):.4f} ± {np.std(valid_vals):.4f}")
                else:
                    logger.info(f"    {m_name}: N/A")
# Log classic results if they exist
if 'classic_training_results' in final_results and final_results['classic_training_results']:
logger.info("--- Classic Training Test Results ---")
classic_res = final_results['classic_training_results']
for metric, value in classic_res.items():
logger.info(f"{metric}: {value:.4f}")
logger.info("-------------------------------------------------")
end_time = time.perf_counter()
@ -434,12 +594,6 @@ def run():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
# --- Configuration Loading ---
try:
@ -448,10 +602,20 @@ def run():
# Error already logged in load_config
sys.exit(1)
# --- Setup based on Config ---
# 1. Set Log Level
log_level_name = config.log_level.upper()
log_level = getattr(logging, log_level_name, logging.INFO)
logger.setLevel(log_level)
logger.info(f"Log level set to: {log_level_name}")
if log_level == logging.DEBUG:
logger.debug("# --- Debug mode enabled via config. --- #")
# 2. Set Seed
set_seeds(config.random_seed)
# 3. Determine Output Directory
output_dir = Path(config.output_dir)
# --- Pipeline Execution ---
try:
@ -459,7 +623,7 @@ def run():
except SystemExit as e:
logger.warning(f"Pipeline exited with code {e.code}.")
        sys.exit(e.code) # Propagate exit code
except Exception as e:
logger.critical(f"An critical error occurred during pipeline execution: {e}", exc_info=True)
sys.exit(1)

123
main.py
View File

@ -1,123 +0,0 @@
import logging
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Any
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def load_config(config_path: Path) -> MainConfig:
"""
Load and validate configuration from YAML file.
"""
# TODO: Implement config loading
pass
def set_seeds(seed: int = 42) -> None:
"""
Set random seeds for reproducibility.
"""
# TODO: Implement seed setting
pass
def determine_device(config: MainConfig) -> torch.device:
"""
Determine the device to use for training.
"""
# TODO: Implement device determination
pass
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""
Calculate mean and standard deviation of metrics across folds.
"""
# TODO: Implement metric aggregation
pass
def main():
# Load configuration
config = load_config(Path("config.yaml"))
# Set random seeds
set_seeds()
# Determine device
device = determine_device(config)
# Load raw data
df = load_raw_data(config.data)
# Initialize CV splitter
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
# Initialize list to store fold metrics
all_fold_metrics = []
# Cross-validation loop
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split(), 1):
logger.info(f"Starting fold {fold_num}")
# Prepare data loaders
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
df, train_idx, val_idx, test_idx,
config.features, config.training, config.evaluation
)
# Update model config with input size
config.model.input_size = input_size
# Initialize model
model = LSTMForecastModel(config.model).to(device)
# Initialize loss function
loss_fn = torch.nn.MSELoss() if config.training.loss_function == "MSE" else torch.nn.L1Loss()
# Initialize scheduler if configured
scheduler = None
if config.training.scheduler_step_size is not None:
# TODO: Initialize scheduler
pass
# Initialize trainer
trainer = Trainer(
model, train_loader, val_loader, loss_fn, device,
config.training, scheduler, target_scaler
)
# Train model
trainer.train()
# Evaluate on test set
fold_metrics = evaluate_fold(
model, test_loader, loss_fn, device,
target_scaler, config.evaluation, fold_num
)
all_fold_metrics.append(fold_metrics)
# Optional: Clear GPU memory
if device.type == "cuda":
torch.cuda.empty_cache()
# Aggregate metrics
aggregated_metrics = aggregate_cv_metrics(all_fold_metrics)
# Log final results
logger.info("Cross-validation results:")
for metric, stats in aggregated_metrics.items():
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
if __name__ == "__main__":
main()

25
optim_config.yaml Normal file
View File

@ -0,0 +1,25 @@
# Configuration for the battery optimization runs
# Initial state of charge of the battery (MWh)
initial_b: 0.0
# Maximum energy capacity of the battery (MWh)
max_capacity: 1.0
# Maximum charge/discharge power rate of the battery (MW)
max_rate: 1.0
# The length of the time window (in hours) for which the optimization is run
# This should match the forecast horizon of the models being evaluated.
optimization_horizon_hours: 24
# List of models/ensembles to evaluate. Field names follow what optim_run.py reads from each
# entry: name, type ('model' or 'ensemble'), model_path (checkpoint .ckpt or ensemble
# definition .json), model_config_path (the forecasting config used for that artifact) and an
# optional target_scaler_path.
models:
  - name: "Model_A"
    type: "model"
    model_path: "path/to/model_a_checkpoint.ckpt"
    model_config_path: "configs/model_a_forecasting_config.yaml"
    target_scaler_path: "path/to/model_a_target_scaler.pt" # Optional
  - name: "Model_B"
    type: "model"
    model_path: "path/to/model_b_checkpoint.ckpt"
    model_config_path: "configs/model_b_forecasting_config.yaml"
# Add more models here
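
The Pydantic schema behind OptimizationRunConfig (optimizer/utils/optim_config.py) is not included in this commit view. The minimal sketch below is inferred from the attribute accesses in optim_run.py (name, type, model_path, model_config_path, target_scaler_path) and the top-level keys above; field names and defaults not visible in the diff are assumptions, not the actual implementation.

# Illustrative sketch only -- the real model lives in optimizer/utils/optim_config.py (not shown).
from typing import List, Optional
from pydantic import BaseModel

class ModelEvalConfig(BaseModel):
    name: str
    type: str                                # assumed values: 'model' or 'ensemble'
    model_path: str                          # .ckpt checkpoint or ensemble definition .json
    model_config_path: str                   # forecasting config YAML used to train this artifact
    target_scaler_path: Optional[str] = None

class OptimizationRunConfig(BaseModel):
    initial_b: float
    max_capacity: float
    max_rate: float
    optimization_horizon_hours: int
    models: List[ModelEvalConfig]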

544
optim_run.py Normal file
View File

@ -0,0 +1,544 @@
import pandas as pd
import numpy as np
import yaml
import logging
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
# Import Forecasting Providers
from forecasting_model.data_processing import load_raw_data
from optimizer.forecasting.base import ForecastProvider
from optimizer.forecasting.single_model import SingleModelProvider
from optimizer.forecasting.ensemble import EnsembleProvider
from optimizer.optimization.battery import solve_battery_optimization_hourly
from optimizer.utils.optim_config import OptimizationRunConfig
from forecasting_model.utils.forecast_config_model import DataConfig, MainConfig
# Import the newly created loading functions
from optimizer.utils.model_io import load_single_model_artifact, load_ensemble_artifact
from typing import Dict, Any, Optional, Union # Added Union
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
logger = logging.getLogger()
def load_optimization_config(config_path: str) -> OptimizationRunConfig | None:
"""Loads the main optimization configuration from a YAML file."""
logger.info(f"Loading optimization config from {config_path}")
try:
with open(config_path, 'r') as f:
config_data = yaml.safe_load(f)
return OptimizationRunConfig(**config_data)
except FileNotFoundError:
logger.error(f"Optimization config file not found at {config_path}")
return None
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML optimization config file: {e}")
return None
except Exception as e:
logger.error(f"Error loading optimization config: {e}", exc_info=True)
return None
def load_main_forecasting_config(config_path: str) -> MainConfig | None:
"""Loads the main forecasting configuration from a YAML file."""
logger.info(f"Loading main forecasting config from: {config_path}")
try:
with open(config_path, 'r') as f:
config_data = yaml.safe_load(f)
# Assuming MainConfig is the top-level model in forecast_config_model.py
return MainConfig(**config_data)
except FileNotFoundError:
logger.error(f"Main forecasting config file not found at {config_path}")
return None
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML main forecasting config file: {e}")
return None
except Exception as e:
logger.error(f"Error loading main forecasting config: {e}", exc_info=True)
return None
# --- Main Execution Logic ---
# 1. Load configs
# 2. Initialize forecast providers
# 3. For each time window:
# a. Get forecasts for all horizons
# b. Run optimization for each horizon
# c. Store results
# 4. Evaluate and visualize
if __name__ == "__main__":
logger.info("Starting battery optimization evaluation with baseline and models/ensembles.")
# --- Load Main Optimization Config ---
optimization_config_path = "optim_config.yaml"
optimization_config = load_optimization_config(optimization_config_path)
if optimization_config is None:
logger.critical("Failed to load main optimization config. Exiting.") # Use critical for exit
exit(1) # Use non-zero exit code for error
optim_run_script_dir = Path(__file__).parent
if not optimization_config.models:
logger.critical("No models or ensembles specified in optimization config. Exiting.")
exit(1)
# Try to load the main forecasting config for the first model/ensemble to get the data path
first_model_config_path = Path(optimization_config.models[0].model_config_path)
main_forecasting_config_for_data = load_main_forecasting_config(str(first_model_config_path))
if main_forecasting_config_for_data is None:
logger.critical("Failed to load forecasting config for the first specified model/ensemble to get data path. Exiting.")
exit(1)
# Use the DataConfig from the first loaded forecasting config
historical_data_config = DataConfig(
data_path=main_forecasting_config_for_data.data.data_path,
raw_datetime_col=main_forecasting_config_for_data.data.raw_datetime_col,
raw_datetime_format=main_forecasting_config_for_data.data.raw_datetime_format,
datetime_col=main_forecasting_config_for_data.data.datetime_col,
raw_target_col=main_forecasting_config_for_data.data.raw_target_col,
target_col=main_forecasting_config_for_data.data.target_col,
expected_frequency=main_forecasting_config_for_data.data.expected_frequency,
fill_initial_target_nans=main_forecasting_config_for_data.data.fill_initial_target_nans
)
logger.info(f"Loading original historical data from: {historical_data_config.data_path}")
try:
full_historical_df = load_raw_data(historical_data_config)
if full_historical_df.empty:
logger.critical("Loaded original historical data is empty. Cannot proceed. Exiting.")
exit(1)
# Ensure data is at the expected frequency and sorted
full_historical_df = full_historical_df.sort_index().asfreq(historical_data_config.expected_frequency)
# Fill any NaNs introduced by asfreq if not already handled by fill_initial_target_nans
if full_historical_df[historical_data_config.target_col].isnull().any():
logger.warning(f"NaNs found after setting frequency {historical_data_config.expected_frequency}. Applying ffill().bfill().")
full_historical_df[historical_data_config.target_col] = full_historical_df[historical_data_config.target_col].ffill().bfill()
if full_historical_df[historical_data_config.target_col].isnull().any():
logger.critical("NaNs still remain after filling. Cannot proceed. Exiting.")
exit(1)
logger.info(f"Original historical data loaded and prepared. Shape: {full_historical_df.shape}")
except Exception as e:
logger.critical(f"Failed to load or prepare original historical data from {historical_data_config.data_path}: {e}", exc_info=True)
exit(1)
# --- Define Evaluation Window and Step ---
optimization_horizon_hours = optimization_config.optimization_horizon_hours
step_size_hours = 1 # Evaluate every hour by sliding the window by 1 hour
logger.info(f"Using optimization horizon: {optimization_horizon_hours} hours with a step size of {step_size_hours} hour(s).")
# --- Storage for results per time window ---
window_results_list = []
# --- Load Models/Ensembles and Instantiate Providers ---
# Store loaded provider instances, keyed by the name from optim_config
forecast_providers: Dict[str, ForecastProvider] = {} # Store provider instances
for model_eval_config in optimization_config.models:
provider_name = model_eval_config.name
artifact_type = model_eval_config.type
artifact_path = Path(model_eval_config.model_path) # Path to .ckpt or .json
config_path = Path(model_eval_config.model_config_path) # Path to YAML config
provider_instance: Optional[ForecastProvider] = None # Initialize provider instance
if artifact_type == 'model':
logger.info(f"Attempting to load single model artifact and create provider: {provider_name}")
target_scaler_path = Path(model_eval_config.target_scaler_path) if model_eval_config.target_scaler_path else None
input_size_path = artifact_path.parent / "input_size.pt" # Derive path convention
if not input_size_path.exists() and artifact_path.parent.name == 'checkpoints':
input_size_path = artifact_path.parent.parent / "input_size.pt"
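            # Path convention from the training pipeline: run_single_fold writes input_size.pt into
            # the fold directory itself while checkpoints go to fold_XX/checkpoints/, hence the
            # parent.parent fallback when the artifact path points inside a checkpoints/ folder.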
loaded_artifact_info = load_single_model_artifact(
model_path=artifact_path,
config_path=config_path,
input_size_path=input_size_path,
target_scaler_path=target_scaler_path
)
if loaded_artifact_info:
try:
provider_instance = SingleModelProvider(
model_instance=loaded_artifact_info['model_instance'],
feature_config=loaded_artifact_info['feature_config'],
target_col=loaded_artifact_info['main_forecasting_config'].data.target_col, # Get target col from loaded config
target_scaler=loaded_artifact_info['target_scaler']
)
# Validation check (basic horizon check)
if 1 not in provider_instance.feature_config.forecast_horizon:
logger.error(f"Model '{provider_name}' forecast horizon {provider_instance.feature_config.forecast_horizon} does not include 1 hour. Cannot use for this evaluation.")
provider_instance = None # Discard if validation fails
else:
logger.info(f"Successfully created SingleModelProvider for '{provider_name}'.")
except Exception as e:
logger.error(f"Failed to instantiate SingleModelProvider for '{provider_name}': {e}", exc_info=True)
else:
logger.warning(f"Single model artifact '{provider_name}' could not be loaded. Skipping provider creation.")
elif artifact_type == 'ensemble':
logger.info(f"Attempting to load ensemble artifact and create provider: {provider_name}")
hpo_base_output_dir_for_ensemble = artifact_path.parent
loaded_artifact_info = load_ensemble_artifact(
ensemble_definition_path=artifact_path,
hpo_base_output_dir=hpo_base_output_dir_for_ensemble
)
if loaded_artifact_info:
try:
# Ensure necessary keys are present before instantiation
required_keys = ['fold_artifacts', 'ensemble_method', 'ensemble_feature_config', 'ensemble_target_col']
if not all(key in loaded_artifact_info for key in required_keys):
missing_keys = [key for key in required_keys if key not in loaded_artifact_info]
raise ValueError(f"Ensemble artifact info is missing required keys: {missing_keys}")
provider_instance = EnsembleProvider(
fold_artifacts=loaded_artifact_info['fold_artifacts'],
ensemble_method=loaded_artifact_info['ensemble_method'],
ensemble_feature_config=loaded_artifact_info['ensemble_feature_config'],
ensemble_target_col=loaded_artifact_info['ensemble_target_col']
)
# Validation check (basic horizon check)
if 1 not in provider_instance.common_forecast_horizons:
logger.error(f"Ensemble '{provider_name}' common forecast horizon {provider_instance.common_forecast_horizons} does not include 1 hour. Cannot use for this evaluation.")
provider_instance = None # Discard if validation fails
else:
logger.info(f"Successfully created EnsembleProvider for '{provider_name}'.")
except Exception as e:
logger.error(f"Failed to instantiate EnsembleProvider for '{provider_name}': {e}", exc_info=True)
else:
logger.warning(f"Ensemble artifact '{provider_name}' could not be loaded. Skipping provider creation.")
else:
logger.error(f"Unknown artifact type '{artifact_type}' for '{provider_name}'. Skipping.")
continue # Skip to next model_eval_config
# Store the successfully created provider instance
if provider_instance:
forecast_providers[provider_name] = provider_instance
# --- End Loading ---
if not forecast_providers:
logger.critical("No forecast providers were successfully created. Cannot proceed with evaluation. Exiting.")
exit(1)
# --- Calculate Max Lookback Needed Across All Providers ---
max_required_lookback = 0
for provider_name, provider in forecast_providers.items():
try:
lookback = provider.get_required_lookback()
max_required_lookback = max(max_required_lookback, lookback)
logger.debug(f"Provider '{provider_name}' requires lookback: {lookback}")
except AttributeError:
logger.error(f"Provider '{provider_name}' does not have a 'get_required_lookback' method. Cannot determine lookback requirements accurately. Exiting.")
exit(1)
except Exception as e:
logger.error(f"Error getting lookback for provider '{provider_name}': {e}. Exiting.", exc_info=True)
exit(1)
logger.info(f"Maximum lookback required across all providers: {max_required_lookback} hours.")
# The first timestamp for which we can generate a forecast needs `max_required_lookback` points *before* it.
# If optimization starts at `window_start_time` (iloc `i`), the forecast generation needs data up to `i-1`.
# The historical slice passed to `get_forecast` must contain `max_required_lookback` points, ending at `i-1`.
# Therefore, the slice starts at `i - max_required_lookback`. This must be >= 0.
# So, `i >= max_required_lookback`.
first_window_start_iloc = max_required_lookback
# The last window starts such that the window ends within the data: `i + optimization_horizon_hours - 1 < len(df)`
# So, `i < len(df) - optimization_horizon_hours + 1`.
last_window_start_iloc = len(full_historical_df) - optimization_horizon_hours
if first_window_start_iloc > last_window_start_iloc:
logger.critical(f"Not enough historical data ({len(full_historical_df)} hours) for the required lookback ({max_required_lookback}) and optimization horizon ({optimization_horizon_hours}). First possible window start iloc: {first_window_start_iloc}, last possible: {last_window_start_iloc}. Exiting.")
exit(1)
logger.info(f"Evaluating over historical windows from iloc {first_window_start_iloc} to {last_window_start_iloc}.")
# --- Evaluation Loop ---
for i in range(first_window_start_iloc, last_window_start_iloc + 1, step_size_hours):
# Define the actual optimization window in terms of iloc and time
window_start_iloc = i
window_end_iloc = i + optimization_horizon_hours - 1 # Inclusive index for the window end
# Check if the window is complete within the dataset bounds
if window_end_iloc >= len(full_historical_df):
logger.warning(f"Skipping window starting at iloc {window_start_iloc}: extends beyond available data (needs up to iloc {window_end_iloc}, max is {len(full_historical_df)-1}).")
continue
window_timestamps = full_historical_df.index[window_start_iloc : window_end_iloc + 1]
# Double-check length just in case
if len(window_timestamps) != optimization_horizon_hours:
logger.warning(f"Skipping window starting at iloc {window_start_iloc} due to unexpected timestamp slice length ({len(window_timestamps)} instead of {optimization_horizon_hours} hours).")
continue
window_start_time = window_timestamps[0]
window_end_time = window_timestamps[-1]
logger.info(f"Processing window: {window_start_time.strftime('%Y-%m-%d %H:%M')} to {window_end_time.strftime('%Y-%m-%d %H:%M')} (iloc {window_start_iloc})")
# --- Prepare Historical Slice for Forecasting ---
# We need data *up to* the beginning of the optimization window, including lookback.
# Slice should end at iloc `window_start_iloc - 1`.
# Slice should start at `window_start_iloc - max_required_lookback`.
hist_slice_start_iloc = max(0, window_start_iloc - max_required_lookback)
hist_slice_end_iloc = window_start_iloc # Exclusive end iloc for slicing, so it includes up to window_start_iloc - 1
if hist_slice_end_iloc <= hist_slice_start_iloc:
logger.error(f"Invalid historical slice range for window starting at {window_start_time}: start_iloc={hist_slice_start_iloc}, end_iloc={hist_slice_end_iloc}. Skipping window.")
continue
historical_slice_for_forecasting = full_historical_df.iloc[hist_slice_start_iloc : hist_slice_end_iloc].copy()
# Check if the slice has the expected length (at least max_required_lookback, unless near start of data)
if len(historical_slice_for_forecasting) < max_required_lookback and window_start_iloc >= max_required_lookback:
logger.warning(f"Historical slice for window starting {window_start_time} has unexpected length {len(historical_slice_for_forecasting)}, expected {max_required_lookback}. Check slicing logic. Skipping.")
continue
elif len(historical_slice_for_forecasting) == 0:
logger.warning(f"Historical slice for window starting {window_start_time} is empty. Skipping.")
continue
logger.debug(f"Using historical slice from {historical_slice_for_forecasting.index.min()} to {historical_slice_for_forecasting.index.max()} (Length: {len(historical_slice_for_forecasting)}) for forecasting.")
# --- Collect Window Results ---
window_results = {
'start_time': window_start_time,
'end_time': window_end_time,
'actual_prices': full_historical_df[historical_data_config.target_col].iloc[window_start_iloc : window_end_iloc + 1].tolist()
}
# --- Baseline Optimization ---
baseline_prices_input = np.array(window_results['actual_prices'])
logger.debug(f"Running baseline optimization for window starting {window_start_time}")
try:
baseline_status, baseline_profit, baseline_power, baseline_B = solve_battery_optimization_hourly(
baseline_prices_input,
optimization_config.initial_b,
optimization_config.max_capacity,
optimization_config.max_rate
)
window_results['baseline'] = {
"status": baseline_status,
"profit": baseline_profit,
"power_schedule": baseline_power.tolist() if baseline_power is not None else None,
"B_schedule": baseline_B.tolist() if baseline_B is not None else None
}
logger.debug(f"Baseline profit: {baseline_profit if baseline_profit is not None else 'N/A'}")
except Exception as e:
logger.error(f"Baseline optimization failed for window starting {window_start_time}: {e}", exc_info=True)
window_results['baseline'] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None}
# --- Forecast Provider Optimizations ---
for provider_name, provider_instance in forecast_providers.items():
logger.debug(f"Generating forecast and running optimization for provider '{provider_name}' for window starting {window_start_time}")
# Generate forecast using the provider's get_forecast method
try:
forecast_prices_input = provider_instance.get_forecast(
historical_data_slice=historical_slice_for_forecasting.copy(), # Pass a copy
optimization_horizon_hours=optimization_horizon_hours
)
except Exception as e:
logger.error(f"Error calling get_forecast for provider '{provider_name}': {e}", exc_info=True)
forecast_prices_input = None
if forecast_prices_input is None or len(forecast_prices_input) != optimization_horizon_hours:
logger.warning(f"Forecast generation failed or returned incorrect length ({len(forecast_prices_input) if forecast_prices_input is not None else 0} instead of {optimization_horizon_hours}) for provider '{provider_name}' window starting {window_start_time}. Skipping optimization.")
window_results[provider_name] = {"status": "Forecast Generation Failed", "profit": None, "power_schedule": None, "B_schedule": None}
continue # Skip optimization for this provider/window
# Ensure the forecast input is a numpy array of the correct shape
if not isinstance(forecast_prices_input, np.ndarray) or forecast_prices_input.shape != (optimization_horizon_hours,):
logger.error(f"Forecast input for provider '{provider_name}' has incorrect format ({type(forecast_prices_input)}, shape {forecast_prices_input.shape if isinstance(forecast_prices_input, np.ndarray) else 'N/A'}). Expected ({optimization_horizon_hours},). Skipping optimization.")
window_results[provider_name] = {"status": "Invalid Forecast Format", "profit": None, "power_schedule": None, "B_schedule": None}
continue
# --- Run Optimization with Forecast Prices ---
try:
model_status, model_profit, model_power, model_B = solve_battery_optimization_hourly(
forecast_prices_input,
optimization_config.initial_b,
optimization_config.max_capacity,
optimization_config.max_rate
)
window_results[provider_name] = {
"status": model_status,
"profit": model_profit,
"power_schedule": model_power.tolist() if model_power is not None else None,
"B_schedule": model_B.tolist() if model_B is not None else None
}
logger.debug(f"Provider '{provider_name}' profit: {model_profit if model_profit is not None else 'N/A'}")
except Exception as e:
logger.error(f"Optimization failed for provider '{provider_name}' window starting {window_start_time}: {e}", exc_info=True)
window_results[provider_name] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None}
# Append results for this window
window_results_list.append(window_results)
logger.debug(f"Finished processing window starting at: {window_start_time.strftime('%Y-%m-%d %H:%M')}")
logger.info("Finished processing all evaluation windows.")
# --- Post-processing and Plotting ---
logger.info("Starting results analysis and plotting.")
if not window_results_list:
logger.warning("No window results were collected. Skipping plotting.")
exit(0) # Not necessarily an error state
# Convert results list to a DataFrame
flat_results = []
successfully_loaded_provider_names = list(forecast_providers.keys()) # Names of providers used
for window_res in window_results_list:
base_info = {
'start_time': window_res['start_time'],
'end_time': window_res['end_time'],
}
# Add baseline results
flat_results.append({**base_info, 'type': 'baseline', **window_res.get('baseline', {})})
# Add provider results
for provider_name in successfully_loaded_provider_names:
provider_res = window_res.get(provider_name, {}) # Get results or empty dict
flat_results.append({**base_info, 'type': provider_name, **provider_res})
results_df = pd.DataFrame(flat_results)
results_df['start_time'] = pd.to_datetime(results_df['start_time']) # Ensure datetime type
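# Each row now carries: start_time, end_time, type ('baseline' or a provider name), status, profit, power_schedule, B_schedule.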
# Filter out rows where essential optimization results are missing
# results_df.dropna(subset=['profit', 'power_schedule'], inplace=True) # Be careful with dropna
# Calculate Profit Absolute Error over time
profit_pivot = results_df.pivot_table(index='start_time', columns='type', values='profit')
mae_df = pd.DataFrame(index=profit_pivot.index)
if 'baseline' in profit_pivot.columns:
for provider_name in successfully_loaded_provider_names:
if provider_name in profit_pivot.columns:
# Use .sub() and .abs() to handle potential NaNs gracefully
mae_df[f'Profit_Abs_Error_{provider_name}'] = profit_pivot[provider_name].sub(profit_pivot['baseline']).abs()
else:
logger.warning(f"Cannot calculate profit MAE for provider '{provider_name}'. Data not found in pivoted results.")
else:
logger.warning("Cannot calculate profit MAE because baseline results are missing or incomplete.")
# --- Plotting ---
# Plot 1: Price and First Hour's Power Schedule Over Time
logger.info("Generating Price and Power Schedule plot.")
continuous_power_data = []
for window_res in window_results_list:
start_time = window_res['start_time']
# Baseline power
baseline_data = window_res.get('baseline', {})
if baseline_data.get('power_schedule') and len(baseline_data['power_schedule']) > 0:
continuous_power_data.append({'time': start_time, 'type': 'baseline', 'power': baseline_data['power_schedule'][0]})
# Provider powers
for provider_name in successfully_loaded_provider_names:
provider_data = window_res.get(provider_name, {})
if provider_data.get('power_schedule') and len(provider_data['power_schedule']) > 0:
continuous_power_data.append({'time': start_time, 'type': provider_name, 'power': provider_data['power_schedule'][0]})
continuous_power_df = pd.DataFrame(continuous_power_data)
if not continuous_power_df.empty:
continuous_power_df['time'] = pd.to_datetime(continuous_power_df['time'])
# Get historical prices corresponding to the evaluation window start times
eval_start_times = results_df['start_time'].unique()
price_plot_df = full_historical_df.loc[eval_start_times, [historical_data_config.target_col]].reset_index()
price_plot_df.rename(columns={price_plot_df.columns[0]: 'time', historical_data_config.target_col: 'price'}, inplace=True) # Use positional index for timestamp column rename
plot_range_start = continuous_power_df['time'].min()
plot_range_end = continuous_power_df['time'].max()
# Filter data for the plot range
filtered_price_df = price_plot_df[(price_plot_df['time'] >= plot_range_start) & (price_plot_df['time'] <= plot_range_end)]
filtered_power_df = continuous_power_df[(continuous_power_df['time'] >= plot_range_start) & (continuous_power_df['time'] <= plot_range_end)]
if not filtered_power_df.empty:
fig1, ax1 = plt.subplots(figsize=(15, 7))
ax2 = ax1.twinx()
sns.lineplot(data=filtered_price_df, x='time', y='price', ax=ax1, color='gray', linestyle='--', label='Historical Price (Window Start)', zorder=1)
ax1.set_ylabel('Price (€/MWh)', color='gray')
ax1.tick_params(axis='y', labelcolor='gray')
sns.lineplot(data=filtered_power_df, x='time', y='power', hue='type', ax=ax2, zorder=2)
ax2.set_ylabel('Power (MW)')
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax2.legend(h1 + h2, l1 + l2, loc='upper left', title='Schedule Type')
ax1.get_legend().remove() # Remove the original legend from ax1
ax1.set_xlabel('Time')
ax1.set_title('Battery Power Schedule (1st Hour) vs. Historical Price (Window Start)')
plt.tight_layout()
plt.savefig("power_schedule_vs_price.png")
logger.info("Price and Power Schedule plot saved as power_schedule_vs_price.png")
# plt.show()
else:
logger.warning("No power data available within the determined plot range.")
else:
logger.warning("No continuous power data generated for plotting power schedule.")
# Plot 2: Absolute Profit Error over time
logger.info("Generating Profit Absolute Error plot.")
if not mae_df.empty and not mae_df.isnull().all().all(): # Check if not empty and not all NaN
fig2, ax = plt.subplots(figsize=(15, 7))
# Use the plot range from the power plot if available
mae_plot_range_start = plot_range_start if 'plot_range_start' in locals() else mae_df.index.min()
mae_plot_range_end = plot_range_end if 'plot_range_end' in locals() else mae_df.index.max()
filtered_mae_df = mae_df[(mae_df.index >= mae_plot_range_start) & (mae_df.index <= mae_plot_range_end)].copy() # Create copy
# Optional: Handle or remove columns that are all NaN within the range
filtered_mae_df.dropna(axis=1, how='all', inplace=True)
if not filtered_mae_df.empty:
sns.lineplot(data=filtered_mae_df, ax=ax)
ax.set_xlabel('Time')
ax.set_ylabel('Absolute Profit Error vs. Baseline (€)')
ax.set_title('Absolute Profit Error of Providers vs. Baseline over Time')
ax.legend(title='Provider Type')
plt.tight_layout()
plt.savefig("profit_abs_error_over_time.png")
logger.info("Profit Absolute Error plot saved as profit_abs_error_over_time.png")
# plt.show()
else:
logger.warning("MAE data is all NaN or empty within the plot range. Skipping MAE plot.")
else:
logger.warning("No valid data available to plot Profit Absolute Error.")
logger.info("Evaluation and plotting completed.")

View File

@@ -0,0 +1,22 @@
# forecasting/base.py
from typing import List, Optional
import pandas as pd
import numpy as np
class ForecastProvider:
    """Interface expected by the optimization evaluation loop."""

    def get_forecast(self,
                     historical_data_slice: pd.DataFrame,
                     optimization_horizon_hours: int) -> Optional[np.ndarray]:
        """Returns an hourly forecast of length `optimization_horizon_hours`, or None on failure."""
        raise NotImplementedError

    def get_required_lookback(self) -> int:
        """Returns the minimum number of historical data points required."""
        raise NotImplementedError

    def get_forecast_horizons(self) -> List[int]:
        """Returns the native forecast horizons of the underlying model(s)."""
        raise NotImplementedError
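

# A minimal illustrative implementation (not part of the project): a naive persistence
# provider that repeats the last observed value of a placeholder "Price" column. It only
# assumes the interface defined above.
class NaivePersistenceProvider(ForecastProvider):
    def __init__(self, target_col: str = "Price"):
        self.target_col = target_col

    def get_forecast(self,
                     historical_data_slice: pd.DataFrame,
                     optimization_horizon_hours: int) -> np.ndarray:
        # Repeat the most recent observation for every hour of the horizon.
        last_value = historical_data_slice[self.target_col].iloc[-1]
        return np.full(optimization_horizon_hours, last_value, dtype=float)

    def get_required_lookback(self) -> int:
        return 1

    def get_forecast_horizons(self) -> List[int]:
        return [1]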

View File

@@ -0,0 +1,188 @@
import logging
from typing import List, Dict, Any, Optional
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast
logger = logging.getLogger(__name__)
class EnsembleProvider(ForecastProvider):
"""Provides forecasts using an ensemble of trained LSTM models."""
def __init__(
self,
fold_artifacts: List[Dict[str, Any]],
ensemble_method: str,
ensemble_feature_config: FeatureConfig, # Assumed consistent across folds by loading logic
ensemble_target_col: str, # Assumed consistent
):
if not fold_artifacts:
raise ValueError("EnsembleProvider requires at least one fold artifact.")
self.fold_artifacts = fold_artifacts
self.ensemble_method = ensemble_method
# Store common config for reference, but use fold-specific details in get_forecast
self.ensemble_feature_config = ensemble_feature_config
self.ensemble_target_col = ensemble_target_col
self.common_forecast_horizons = sorted(ensemble_feature_config.forecast_horizon) # Assumed consistent
# Calculate max lookback needed across all folds
max_lookback = 0
for i, fold in enumerate(fold_artifacts):
try:
fold_feature_config = fold['feature_config']
fold_seq_len = fold_feature_config.sequence_length
feature_lookback = 0
if fold_feature_config.lags:
feature_lookback = max(feature_lookback, max(fold_feature_config.lags))
if fold_feature_config.rolling_window_sizes:
feature_lookback = max(feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
fold_total_lookback = fold_seq_len + feature_lookback
max_lookback = max(max_lookback, fold_total_lookback)
except KeyError as e:
raise ValueError(f"Fold artifact {i} is missing expected key: {e}") from e
except Exception as e:
raise ValueError(f"Error processing fold artifact {i} for lookback calculation: {e}") from e
self._required_lookback = max_lookback
logger.debug(f"EnsembleProvider initialized with {len(fold_artifacts)} folds. Method: '{ensemble_method}'. Required lookback: {self._required_lookback}")
if ensemble_method not in ['mean', 'median']:
raise ValueError(f"Unsupported ensemble method: {ensemble_method}. Use 'mean' or 'median'.")
def get_required_lookback(self) -> int:
return self._required_lookback
def get_forecast(
self,
historical_data_slice: pd.DataFrame,
optimization_horizon_hours: int
) -> np.ndarray | None:
"""
Generates forecasts from each fold model, interpolates, and aggregates.
"""
logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours using {self.ensemble_method}.")
if len(historical_data_slice) < self._required_lookback:
logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
return None
fold_forecasts_interpolated = []
last_actual_price = historical_data_slice[self.ensemble_target_col].iloc[-1] # Common anchor for all folds
for i, fold_artifact in enumerate(self.fold_artifacts):
fold_id = fold_artifact.get("fold_id", i + 1)
try:
fold_model: LSTMForecastLightningModule = fold_artifact['model_instance']
fold_feature_config: FeatureConfig = fold_artifact['feature_config']
fold_target_scaler: Optional[Any] = fold_artifact['target_scaler']
fold_target_col: str = fold_artifact['main_forecasting_config'].data.target_col # Use fold specific target
fold_seq_len = fold_feature_config.sequence_length
fold_horizons = sorted(fold_feature_config.forecast_horizon)
# Calculate lookback needed *for this specific fold* to check slice length
fold_feature_lookback = 0
if fold_feature_config.lags: fold_feature_lookback = max(fold_feature_lookback, max(fold_feature_config.lags))
if fold_feature_config.rolling_window_sizes: fold_feature_lookback = max(fold_feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
fold_total_lookback = fold_seq_len + fold_feature_lookback
if len(historical_data_slice) < fold_total_lookback:
logger.warning(f"Fold {fold_id}: Skipping fold. Insufficient historical data in slice for this fold's lookback ({fold_total_lookback} needed).")
continue
# 1. Feature Engineering (using fold's config)
# Slice needs to be long enough for this fold's total lookback.
# The input slice `historical_data_slice` should already be long enough based on max_lookback.
engineered_df_fold = engineer_features(historical_data_slice.copy(), fold_target_col, fold_feature_config)
if engineered_df_fold.isnull().any().any():
logger.warning(f"Fold {fold_id}: NaNs found after feature engineering. Attempting fill.")
engineered_df_fold = engineered_df_fold.ffill().bfill()
if engineered_df_fold.isnull().any().any():
logger.error(f"Fold {fold_id}: NaNs persist after fill. Skipping fold.")
continue
# 2. Create *one* input sequence (using fold's sequence length)
if len(engineered_df_fold) < fold_seq_len:
logger.error(f"Fold {fold_id}: Engineered data ({len(engineered_df_fold)}) is shorter than fold sequence length ({fold_seq_len}). Skipping fold.")
continue
input_sequence_data_fold = engineered_df_fold.iloc[-fold_seq_len:].copy()
feature_columns_fold = [col for col in engineered_df_fold.columns if col != fold_target_col] # Example
if not feature_columns_fold: feature_columns_fold = engineered_df_fold.columns.tolist()
input_sequence_np_fold = input_sequence_data_fold[feature_columns_fold].values
if input_sequence_np_fold.shape != (fold_seq_len, len(feature_columns_fold)):
logger.error(f"Fold {fold_id}: Input sequence has wrong shape. Expected ({fold_seq_len}, {len(feature_columns_fold)}), got {input_sequence_np_fold.shape}. Skipping fold.")
continue
input_tensor_fold = torch.FloatTensor(input_sequence_np_fold).unsqueeze(0)
# 3. Run Inference (using fold's model)
fold_model.eval()
with torch.no_grad():
predictions_scaled_fold = fold_model(input_tensor_fold) # Shape (1, num_fold_horizons)
if predictions_scaled_fold.ndim != 2 or predictions_scaled_fold.shape[0] != 1 or predictions_scaled_fold.shape[1] != len(fold_horizons):
logger.error(f"Fold {fold_id}: Prediction output shape mismatch. Expected (1, {len(fold_horizons)}), got {predictions_scaled_fold.shape}. Skipping fold.")
continue
predictions_scaled_np_fold = predictions_scaled_fold.squeeze(0).cpu().numpy()
# 4. Inverse Transform (using fold's scaler)
predictions_original_scale_fold = predictions_scaled_np_fold
if fold_target_scaler:
try:
predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to apply inverse transform: {e}. Skipping fold.", exc_info=True)
continue
# 5. Interpolate (using fold's horizons)
interpolated_forecast_fold = interpolate_forecast(
native_horizons=fold_horizons,
native_predictions=predictions_original_scale_fold,
target_horizon=optimization_horizon_hours,
last_known_actual=last_actual_price
)
if interpolated_forecast_fold is not None:
fold_forecasts_interpolated.append(interpolated_forecast_fold)
logger.debug(f"Fold {fold_id}: Successfully generated interpolated forecast.")
else:
logger.warning(f"Fold {fold_id}: Interpolation failed. Skipping fold.")
except Exception as e:
logger.error(f"Error processing ensemble fold {fold_id}: {e}", exc_info=True)
continue # Skip this fold on error
# --- Aggregation ---
if not fold_forecasts_interpolated:
logger.error("No successful forecasts generated from any ensemble folds.")
return None
logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} folds using '{self.ensemble_method}'.")
stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0) # Shape (n_folds, target_horizon)
if self.ensemble_method == 'mean':
final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
elif self.ensemble_method == 'median':
final_ensemble_forecast = np.median(stacked_predictions, axis=0)
else:
# Should be caught in __init__, but double-check
logger.error(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
return None
logger.debug(f"EnsembleProvider: Successfully generated forecast.")
return final_ensemble_forecast

View File

@@ -0,0 +1,150 @@
import logging
from typing import List, Dict, Any, Optional
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Imports from our project structure
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast
logger = logging.getLogger(__name__)
class SingleModelProvider(ForecastProvider):
"""Provides forecasts using a single trained LSTM model."""
def __init__(
self,
model_instance: LSTMForecastLightningModule,
feature_config: FeatureConfig,
target_col: str,
target_scaler: Optional[Any], # BaseEstimator, TransformerMixin -> more specific if possible
# input_size: int # Not needed directly if model instance is configured
):
self.model = model_instance
self.feature_config = feature_config
self.target_col = target_col
self.target_scaler = target_scaler
self.sequence_length = feature_config.sequence_length
self.forecast_horizons = sorted(feature_config.forecast_horizon) # Ensure sorted
# Calculate required lookback for feature engineering
feature_lookback = 0
if feature_config.lags:
feature_lookback = max(feature_lookback, max(feature_config.lags))
if feature_config.rolling_window_sizes:
# Rolling window of size W needs W-1 previous points
feature_lookback = max(feature_lookback, max(w - 1 for w in feature_config.rolling_window_sizes))
# Total lookback: sequence length for model input + feature engineering needs
# We need `sequence_length` points for the *last* input sequence.
# The first point of that sequence needs `feature_lookback` points before it.
# So, total points needed before the *end* of the input sequence is sequence_length + feature_lookback.
# Since the input sequence ends *before* the first forecast point (t=1),
# we need `sequence_length + feature_lookback` points before t=1.
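        # For example (illustrative numbers only): sequence_length=72 with lags up to 168
        # gives a required lookback of 72 + 168 = 240 data points.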
self._required_lookback = self.sequence_length + feature_lookback
logger.debug(f"SingleModelProvider initialized. Required lookback: {self._required_lookback} (SeqLen: {self.sequence_length}, FeatLookback: {feature_lookback})")
def get_required_lookback(self) -> int:
return self._required_lookback
def get_forecast(
self,
historical_data_slice: pd.DataFrame,
optimization_horizon_hours: int
) -> np.ndarray | None:
"""
Generates forecast using the single model and interpolates to hourly resolution.
"""
logger.debug(f"SingleModelProvider: Generating forecast for {optimization_horizon_hours} hours.")
if len(historical_data_slice) < self._required_lookback:
logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
return None
try:
# 1. Feature Engineering
# Use the provided slice which already includes the lookback.
engineered_df = engineer_features(historical_data_slice.copy(), self.target_col, self.feature_config)
# Check for NaNs after feature engineering before creating sequences
if engineered_df.isnull().any().any():
logger.warning("NaNs found after feature engineering. Attempting to fill with ffill/bfill.")
# Be careful about filling target vs features if needed
engineered_df = engineered_df.ffill().bfill()
if engineered_df.isnull().any().any():
logger.error("NaNs persist after fill. Cannot create sequences.")
return None
# 2. Create *one* input sequence ending at the last point of the historical slice
# This sequence is used to predict starting from the next hour (t=1)
if len(engineered_df) < self.sequence_length:
logger.error(f"Engineered data ({len(engineered_df)}) is shorter than sequence length ({self.sequence_length}).")
return None
input_sequence_data = engineered_df.iloc[-self.sequence_length:].copy()
# Convert sequence data to numpy array (excluding target if model expects it that way)
# Assuming model takes all engineered features as input
# TODO: Verify the exact features the model expects (target included/excluded?)
# Assuming all columns except maybe the original target are features
feature_columns = [col for col in engineered_df.columns if col != self.target_col] # Example
if not feature_columns: feature_columns = engineered_df.columns.tolist() # Use all if target wasn't dropped
input_sequence_np = input_sequence_data[feature_columns].values
if input_sequence_np.shape != (self.sequence_length, len(feature_columns)):
logger.error(f"Input sequence has wrong shape. Expected ({self.sequence_length}, {len(feature_columns)}), got {input_sequence_np.shape}")
return None
input_tensor = torch.FloatTensor(input_sequence_np).unsqueeze(0) # Add batch dim
# 3. Run Inference
self.model.eval()
with torch.no_grad():
# Model output shape: (1, num_horizons)
predictions_scaled = self.model(input_tensor)
if predictions_scaled.ndim != 2 or predictions_scaled.shape[0] != 1 or predictions_scaled.shape[1] != len(self.forecast_horizons):
logger.error(f"Model prediction output shape mismatch. Expected (1, {len(self.forecast_horizons)}), got {predictions_scaled.shape}.")
return None
predictions_scaled_np = predictions_scaled.squeeze(0).cpu().numpy() # Shape: (num_horizons,)
# 4. Inverse Transform
predictions_original_scale = predictions_scaled_np
if self.target_scaler:
try:
# Scaler expects shape (n_samples, n_features), even if n_features=1
predictions_original_scale = self.target_scaler.inverse_transform(predictions_scaled_np.reshape(-1, 1)).flatten()
logger.debug("Applied inverse transform to predictions.")
except Exception as e:
logger.error(f"Failed to apply inverse transform: {e}", exc_info=True)
# Decide whether to return scaled or None. Returning None is safer.
return None
# 5. Interpolate
# Use the last actual price from the input data as the anchor point t=0
last_actual_price = historical_data_slice[self.target_col].iloc[-1]
interpolated_forecast = interpolate_forecast(
native_horizons=self.forecast_horizons,
native_predictions=predictions_original_scale,
target_horizon=optimization_horizon_hours,
last_known_actual=last_actual_price
)
if interpolated_forecast is None:
logger.error("Interpolation step failed.")
return None
logger.debug(f"SingleModelProvider: Successfully generated forecast.")
return interpolated_forecast
except Exception as e:
logger.error(f"Error during single model forecast generation: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,67 @@
from typing import List, Optional, Dict, Any
import numpy as np
import logging
logger = logging.getLogger(__name__)
# --- Interpolation Helper ---
def interpolate_forecast(
native_horizons: List[int],
native_predictions: np.ndarray,
target_horizon: int,
last_known_actual: Optional[float] = None # Optional: use last known price as t=0 for anchor
) -> np.ndarray | None:
"""
Linearly interpolates model predictions at native horizons to a full hourly sequence.
Args:
native_horizons: List of horizons the model predicts (e.g., [1, 6, 12, 24]). Must not be empty.
native_predictions: Numpy array of predictions corresponding to native_horizons. Must not be empty.
target_horizon: The desired length of the hourly forecast (e.g., 24).
last_known_actual: Optional last actual price before the forecast starts (at t=0). Used as anchor if 0 not in native_horizons.
Returns:
A numpy array of shape (target_horizon,) with interpolated values, or None on error.
"""
if not native_horizons or native_predictions is None or native_predictions.size == 0:
logger.error("Cannot interpolate with empty native horizons or predictions.")
return None
if len(native_horizons) != len(native_predictions):
logger.error(f"Mismatched lengths: native_horizons ({len(native_horizons)}) vs native_predictions ({len(native_predictions)})")
return None
try:
# Ensure horizons are sorted
sorted_indices = np.argsort(native_horizons)
        # Use float for potentially non-integer horizons; the horizon points are expected to be > 0 in the usual case.
xp = np.array(native_horizons, dtype=float)[sorted_indices]
fp = native_predictions[sorted_indices]
# Target points for interpolation (hours 1 to target_horizon)
x_target = np.arange(1, target_horizon + 1, dtype=float)
# Add t=0 point if provided and 0 is not already a native horizon
# This anchors the start of the interpolation.
if last_known_actual is not None and xp[0] > 0:
xp = np.insert(xp, 0, 0.0)
fp = np.insert(fp, 0, last_known_actual)
elif xp[0] == 0 and last_known_actual is not None:
logger.debug("Native horizons include 0, using model's prediction for t=0 instead of last_known_actual.")
elif last_known_actual is None and xp[0] > 0:
logger.warning("No last_known_actual provided and native horizons start > 0. Interpolation might be less accurate at the beginning.")
# If the first native horizon is > 1, np.interp will extrapolate constantly backwards from the first point.
# Check if target range requires extrapolation beyond the model's capability
if target_horizon > xp[-1]:
logger.warning(f"Target horizon ({target_horizon}) extends beyond the maximum native forecast horizon ({xp[-1]}). Extrapolation will occur (constant value).")
interpolated_values = np.interp(x_target, xp, fp)
return interpolated_values
except Exception as e:
logger.error(f"Linear interpolation failed: {e}", exc_info=True)
return None
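

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values only, not project data): interpolate
    # predictions at native horizons 1/6/12/24 h to a full 24-hour curve, anchored at
    # the last known actual price.
    example_horizons = [1, 6, 12, 24]
    example_predictions = np.array([52.0, 48.5, 55.0, 60.0])  # e.g. prices in EUR/MWh
    hourly_curve = interpolate_forecast(
        native_horizons=example_horizons,
        native_predictions=example_predictions,
        target_horizon=24,
        last_known_actual=50.0,
    )
    print(hourly_curve.shape)  # -> (24,)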

View File

@@ -0,0 +1,74 @@
import cvxpy as cp
import numpy as np
def solve_battery_optimization_hourly(
hourly_prices, # Array of prices for each hour [0, 1, ..., n-1]
initial_B, # Current state of charge (MWh)
max_capacity=1.0, # MWh
max_rate=1.0 # MW (+ve discharge / -ve charge)
):
"""
Solves the battery scheduling optimization problem assuming hourly steps. We want to decide at the start of each hour t=0..n-1
how much power to buy/sell (P_net_t) and therefore the state of charge at the start of each next hour (B_t+1).
Args:
hourly_prices: Prices (€/MWh) for each hour t=0..n-1.
initial_B: The state of charge at the beginning (time t=0).
max_capacity: Maximum battery energy capacity (MWh).
max_rate: Maximum charge/discharge power rate (MW).
Returns:
Tuple: (status, optimal_profit, power_schedule, B_schedule)
Returns (status, None, None, None) if optimization fails.
"""
n_hours = len(hourly_prices)
# --- CVXPY Variables ---
    # Power flow for each hour t=0..n-1 (+ve discharge / -ve charge, matching the docstring convention)
P = cp.Variable(n_hours, name="Power_Flow_MW")
# State of charge at the START of each hour t=0..n (B[t] is B at hour t)
B = cp.Variable(n_hours + 1, name="State_of_Charge_MWh")
# --- Objective Function ---
# Profit = sum(price[t] * Power[t])
prices = np.array(hourly_prices)
    profit = prices @ P  # Equivalent to cp.sum(cp.multiply(prices, P)) or prices.dot(P)
objective = cp.Maximize(profit)
# --- Constraints ---
constraints = []
# 1. Initial B
constraints.append(B[0] == initial_B)
# 2. B Dynamics: B[t+1] = B[t] - P[t] * 1 hour
    constraints.append(B[1:] == B[:-1] - P)  # P > 0 (discharge) reduces the state of charge
# 3. Power Rate Limits: -max_rate <= P[t] <= max_rate
constraints.append(cp.abs(P) <= max_rate)
# 4. B Limits: 0 <= B[t] <= max_capacity (applies to B[0]...B[n])
constraints.append(B >= 0)
constraints.append(B <= max_capacity)
# --- Problem Definition and Solving ---
problem = cp.Problem(objective, constraints)
try:
# Alternative solvers are ECOS, MOSEK, and SCS
optimal_profit = problem.solve(solver=cp.CLARABEL, verbose=False)
if problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
return (
problem.status,
optimal_profit,
P.value, # NumPy array of optimal power flows per hour
B.value # NumPy array of optimal B at start of each hour
)
else:
print(f"Optimization failed. Solver status: {problem.status}")
return problem.status, None, None, None
except cp.error.SolverError as e:
print(f"Solver Error: {e}")
return "Solver Error", None, None, None
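

if __name__ == "__main__":
    # Minimal usage sketch (synthetic prices, illustrative only). Assumes cvxpy is
    # installed with the CLARABEL solver available.
    rng = np.random.default_rng(42)
    hours = np.arange(24)
    synthetic_prices = 50 + 20 * np.sin(2 * np.pi * hours / 24) + rng.normal(0, 5, size=24)
    status, profit, power_schedule, soc_schedule = solve_battery_optimization_hourly(
        synthetic_prices, initial_B=0.0, max_capacity=1.0, max_rate=1.0
    )
    print(f"Status: {status}, optimal profit: {profit}")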

View File

@@ -0,0 +1,41 @@
## Optimizer Definition for a Constrained n-Step Forecast Trading Problem
We want to optimize the performance of an energy trader given the forecast for n steps.
The battery:
- holds 1MWh
- charges/discharges at max. 1 MW per hour (we can add/lose x*1MW, x \in R)
Prices are stable for the given hour (t) and we sell and buy for the same price.
### Considerations:
- Single variable, P (=x), for each hour t from 0 to n-1.
- If P > 0, it represents discharging (selling power) with a magnitude of P.
- If P < 0, it represents charging (buying power) with a magnitude of -P.
- If P = 0, it represents holding (doing nothing).
- if we only have forecasts for horizons t_n and t_{n+m}, we have to **interpolate** between n .. m
- (alternative considered and rejected: work with the gaps directly and treat dt as the charge time)
#### Variables:
- price_t = price per MWh at hour t (€/MWh)
- B_t (t=0..n) = state of battery in MWh at the start of hour t
- P_t (t=0..n-1) = charge/discharge decision given the base rate of 1 MW/h (P_t > 0 = discharge)
- max_p = 1 (charge/discharge limit) and battery capacity limit (both = 1)
- SoB_initial = 0
- h = horizon \in N^+
### Objective
- We **Maximize**: Sum_{t=0}^{n-1} (price_t * P_t)
### Constraints
- Fixed starting state: SoB_0 = SoB_initial
- Charge/Discharge Limit: (-max_p <= P_t <= max_p) for all t = 0, ..., n-1
- Storage Limit: (0 <= B_t - P_t <= max_p) for all t = 0, ..., n-1
- Future B State: SoB_{t+1} = B_t - P_t for t = 0 to n-1
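### Example (sanity check)
A small worked example (not from the original spec): with n = 2, prices = [10, 100] €/MWh, SoB_initial = 0 and max_p = 1, the optimum is to charge 1 MWh in hour 0 (P_0 = -1, cost 10 €) and discharge it in hour 1 (P_1 = +1, revenue 100 €), for a profit of 100 - 10 = 90 €. Both the rate limit and the capacity limit are respected.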

297
optimizer/utils/model_io.py Normal file
View File

@@ -0,0 +1,297 @@
import logging
import yaml
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
import torch
from sklearn.base import BaseEstimator, TransformerMixin # For scaler type hint
# Import necessary components from forecasting_model
from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
logger = logging.getLogger(__name__)
def load_single_model_artifact(
model_path: Path,
config_path: Path,
input_size_path: Path,
target_scaler_path: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
"""
Loads artifacts for a single trained model checkpoint.
Args:
model_path: Path to the model checkpoint file (.ckpt).
config_path: Path to the corresponding main YAML config file.
input_size_path: Path to the input_size.pt file.
target_scaler_path: Optional path to the target_scaler.pt file.
Returns:
A dictionary containing loaded artifacts ('model_instance', 'feature_config',
'target_scaler', 'main_forecasting_config'), or None if loading fails.
"""
logger.info(f"Loading single model artifact from directory: {model_path.parent}")
loaded_artifacts = {}
try:
# 1. Load Config
if not config_path.is_file():
logger.error(f"Config file not found at {config_path}")
return None
with open(config_path, 'r') as f:
config_data = yaml.safe_load(f)
main_config = MainConfig(**config_data)
loaded_artifacts['main_forecasting_config'] = main_config
loaded_artifacts['feature_config'] = main_config.features
logger.debug(f"Loaded config from {config_path}")
# 2. Load Input Size
if not input_size_path.is_file():
logger.error(f"Input size file not found at {input_size_path}")
return None
input_size = torch.load(input_size_path)
if not isinstance(input_size, int) or input_size <= 0:
logger.error(f"Invalid input size loaded from {input_size_path}: {input_size}")
return None
logger.debug(f"Loaded input size ({input_size}) from {input_size_path}")
# 3. Load Target Scaler (Optional)
target_scaler = None
if target_scaler_path:
if not target_scaler_path.is_file():
logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without scaler.")
else:
try:
target_scaler = torch.load(target_scaler_path)
# Basic check if it looks like a scaler
if not isinstance(target_scaler, (BaseEstimator, TransformerMixin)):
logger.warning(f"Loaded object from {target_scaler_path} might not be a valid scaler ({type(target_scaler)}).")
# Decide if this should be a hard failure or just a warning
else:
logger.debug(f"Loaded target scaler from {target_scaler_path}")
except Exception as e:
logger.error(f"Error loading target scaler from {target_scaler_path}: {e}", exc_info=True)
# Decide if this should be a hard failure
return None # Fail hard if scaler loading fails
loaded_artifacts['target_scaler'] = target_scaler
# 4. Initialize Model Architecture
# Ensure model config forecast horizon matches feature config (should be guaranteed by MainConfig validation)
if set(main_config.model.forecast_horizon) != set(main_config.features.forecast_horizon):
logger.warning(f"Mismatch between model ({main_config.model.forecast_horizon}) and feature ({main_config.features.forecast_horizon}) forecast horizons in config {config_path}. Using feature config.")
# This might indicate an issue with the saved config, but we proceed using the feature config horizon
# main_config.model.forecast_horizon = main_config.features.forecast_horizon # Correct it for model init? Risky.
model_instance = LSTMForecastLightningModule(
model_config=main_config.model,
train_config=main_config.training, # Pass train config if needed
input_size=input_size,
target_scaler=target_scaler # Pass scaler to model if it uses it internally during inference
)
logger.debug("Initialized model architecture.")
# 5. Load Model State Dictionary
if not model_path.is_file():
logger.error(f"Model checkpoint file not found at {model_path}")
return None
# Load onto CPU first to avoid GPU memory issues if the loading machine is different
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        # Unwrap PyTorch Lightning checkpoints ({'state_dict': ..., 'epoch': ...}) and strip the
        # 'model.' prefix that Lightning wrappers commonly add when saving a ckpt.
        if isinstance(state_dict, dict) and 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        if any(key.startswith('model.') for key in state_dict.keys()):
            # e.g. a key like 'model.lstm.weight_ih_l0' becomes 'lstm.weight_ih_l0' (illustrative name)
            state_dict = {k[len('model.'):] if k.startswith('model.') else k: v for k, v in state_dict.items()}
            logger.debug("Adjusted state dict keys (removed 'model.' prefix).")
# Load the state dict
# Use strict=False initially if unsure about exact key matching, but strict=True is safer
try:
load_result = model_instance.load_state_dict(state_dict, strict=True)
logger.debug(f"Model state loaded. Result: {load_result}")
except RuntimeError as e:
logger.error(f"Error loading state dict into model (strict=True): {e}. Trying strict=False.")
try:
load_result = model_instance.load_state_dict(state_dict, strict=False)
logger.warning(f"Model state loaded with strict=False. Result: {load_result}. Check for missing/unexpected keys.")
except Exception as e_false:
logger.error(f"Failed to load state dict even with strict=False: {e_false}", exc_info=True)
return None
model_instance.eval() # Set model to evaluation mode
loaded_artifacts['model_instance'] = model_instance
logger.info(f"Successfully loaded single model artifact: {model_path.name}")
return loaded_artifacts
except FileNotFoundError:
logger.error(f"A required file was not found during artifact loading for {model_path.parent}.", exc_info=True)
return None
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML config file {config_path}: {e}", exc_info=True)
return None
except Exception as e:
logger.error(f"Failed to load single model artifact from {model_path.parent}: {e}", exc_info=True)
return None
def load_ensemble_artifact(
ensemble_definition_path: Path,
hpo_base_output_dir: Path # Base directory where HPO study results (including ensemble JSON) are saved
) -> Optional[Dict[str, Any]]:
"""
Loads artifacts for an ensemble based on its definition JSON file.
Args:
ensemble_definition_path: Path to the _best_ensemble.json file.
hpo_base_output_dir: The base directory where the HPO study ran and
where relative paths within the JSON are anchored.
Returns:
A dictionary containing 'ensemble_method', 'fold_artifacts' (a list
of dictionaries, each like the output of load_single_model_artifact),
'ensemble_feature_config', and 'ensemble_target_col', or None if loading fails.
"""
logger.info(f"Loading ensemble artifact definition from: {ensemble_definition_path}")
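    # Expected JSON layout (a sketch inferred from the keys read below; values are illustrative):
    # {
    #   "ensemble_method": "mean",
    #   "ensemble_artifacts_base_dir": "ensemble_runs_artifacts/trial_7",
    #   "fold_models": [
    #     {"fold_id": 1,
    #      "model_path": "fold_1/best_model.ckpt",
    #      "config_path": "fold_1/config.yaml",
    #      "input_size_path": "fold_1/input_size.pt",
    #      "target_scaler_path": "fold_1/target_scaler.pt"}
    #   ]
    # }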
try:
if not ensemble_definition_path.is_file():
logger.error(f"Ensemble definition file not found at: {ensemble_definition_path}")
return None
with open(ensemble_definition_path, 'r') as f:
ensemble_definition = json.load(f)
except json.JSONDecodeError as e:
logger.error(f"Error decoding ensemble definition JSON file: {e}", exc_info=True)
return None
except Exception as e:
logger.error(f"Error loading ensemble definition: {e}", exc_info=True)
return None
# Extract information from the definition
ensemble_method = ensemble_definition.get("ensemble_method")
fold_models_definitions = ensemble_definition.get("fold_models")
# Base directory for artifacts *relative to* hpo_base_output_dir
relative_artifacts_base_dir = ensemble_definition.get("ensemble_artifacts_base_dir")
if not ensemble_method or not fold_models_definitions:
logger.error("Ensemble definition file is missing 'ensemble_method' or 'fold_models' list.")
return None
if not relative_artifacts_base_dir:
logger.error("Ensemble definition file is missing 'ensemble_artifacts_base_dir'. Cannot locate fold artifacts.")
return None
# --- Determine Absolute Path to Fold Artifacts ---
# The paths inside fold_models are relative to ensemble_artifacts_base_dir,
# which itself is relative to hpo_base_output_dir.
absolute_artifacts_base_dir = hpo_base_output_dir / Path(relative_artifacts_base_dir)
logger.debug(f"Absolute base directory for fold artifacts: {absolute_artifacts_base_dir}")
if not absolute_artifacts_base_dir.is_dir():
logger.error(f"Calculated absolute artifact base directory does not exist or is not a directory: {absolute_artifacts_base_dir}")
return None
loaded_fold_artifacts: List[Dict[str, Any]] = []
common_feature_config: Optional[FeatureConfig] = None
common_target_col: Optional[str] = None
logger.info(f"Loading artifacts for {len(fold_models_definitions)} folds defined in the ensemble...")
# --- Load Artifacts for Each Fold ---
for i, fold_def in enumerate(fold_models_definitions):
fold_id = fold_def.get("fold_id", i + 1)
logger.debug(f"--- Loading Fold {fold_id} ---")
model_path_rel = fold_def.get("model_path")
scaler_path_rel = fold_def.get("target_scaler_path")
input_size_path_rel = fold_def.get("input_size_path")
config_path_rel = fold_def.get("config_path")
if not model_path_rel or not input_size_path_rel or not config_path_rel:
logger.error(f"Fold {fold_id}: Definition is missing required path(s) (model, input_size, or config). Skipping fold.")
continue
# Construct absolute paths for this fold's artifacts
try:
abs_model_path = (absolute_artifacts_base_dir / Path(model_path_rel)).resolve()
abs_input_size_path = (absolute_artifacts_base_dir / Path(input_size_path_rel)).resolve()
abs_config_path = (absolute_artifacts_base_dir / Path(config_path_rel)).resolve()
abs_scaler_path = (absolute_artifacts_base_dir / Path(scaler_path_rel)).resolve() if scaler_path_rel else None
logger.debug(f"Fold {fold_id} - Model Path: {abs_model_path}")
logger.debug(f"Fold {fold_id} - Config Path: {abs_config_path}")
logger.debug(f"Fold {fold_id} - Input Size Path: {abs_input_size_path}")
logger.debug(f"Fold {fold_id} - Scaler Path: {abs_scaler_path}")
# Load the artifacts for this single fold using the other function
single_fold_loaded_artifacts = load_single_model_artifact(
model_path=abs_model_path,
config_path=abs_config_path,
input_size_path=abs_input_size_path,
target_scaler_path=abs_scaler_path
)
if single_fold_loaded_artifacts:
# Add fold_id for reference
single_fold_loaded_artifacts['fold_id'] = fold_id
loaded_fold_artifacts.append(single_fold_loaded_artifacts)
logger.info(f"Successfully loaded artifacts for fold {fold_id}.")
# --- Consistency Check (Optional but Recommended) ---
# Store the feature config and target col from the first successful fold
# Then compare subsequent folds against these
current_feature_config = single_fold_loaded_artifacts['feature_config']
current_target_col = single_fold_loaded_artifacts['main_forecasting_config'].data.target_col
if common_feature_config is None:
common_feature_config = current_feature_config
common_target_col = current_target_col
logger.debug(f"Set common feature config and target column based on fold {fold_id}.")
else:
# Compare crucial feature engineering aspects
if common_feature_config.sequence_length != current_feature_config.sequence_length or \
set(common_feature_config.forecast_horizon) != set(current_feature_config.forecast_horizon) or \
common_feature_config.scaling_method != current_feature_config.scaling_method: # Add more checks if needed
logger.error(f"Fold {fold_id}: Feature configuration mismatch with previous folds. Cannot proceed with this ensemble definition.")
# You might want to compare more fields like lags, rolling_windows etc.
return None # Fail hard if configs are inconsistent
if common_target_col != current_target_col:
logger.error(f"Fold {fold_id}: Target column '{current_target_col}' differs from previous folds ('{common_target_col}'). Cannot proceed.")
return None # Fail hard
else:
logger.error(f"Failed to load artifacts for fold {fold_id}. Skipping fold.")
# Decide if ensemble loading should fail if *any* fold fails
# For now, we continue and will check if enough folds loaded later
except TypeError as e:
# Catch potential errors if paths are None or invalid types
logger.error(f"Fold {fold_id}: Error constructing artifact paths - check definition file content: {e}", exc_info=True)
continue
except Exception as e:
logger.error(f"Fold {fold_id}: Unexpected error during loading: {e}", exc_info=True)
continue # Skip this fold
# --- Final Checks and Return ---
if not loaded_fold_artifacts:
logger.error("Failed to load artifacts for *any* fold in the ensemble.")
return None
# Add a check if a minimum number of folds is required (e.g., > 1)
if len(loaded_fold_artifacts) < 1: # Or maybe check against len(fold_models_definitions)?
logger.error(f"Only successfully loaded {len(loaded_fold_artifacts)} folds, which might be insufficient for the ensemble.")
# Decide if this is an error or just a warning
return None
if common_feature_config is None or common_target_col is None:
# This should not happen if loaded_fold_artifacts is not empty, but check anyway
logger.error("Internal error: Could not determine common feature config or target column for the ensemble.")
return None
logger.info(f"Successfully loaded artifacts for {len(loaded_fold_artifacts)} ensemble folds.")
return {
'ensemble_method': ensemble_method,
'fold_artifacts': loaded_fold_artifacts, # List of dicts
'ensemble_feature_config': common_feature_config, # The common config
'ensemble_target_col': common_target_col # The common target column name
}

View File

@@ -0,0 +1,18 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Literal
class ModelEvalConfig(BaseModel):
"""Configuration for evaluating a single forecasting model or an ensemble."""
name: str = Field(..., description="Name of the forecasting model or ensemble.")
type: Literal['model', 'ensemble'] = Field(..., description="Type of evaluation artifact: 'model' for a single checkpoint, 'ensemble' for an ensemble definition JSON.")
model_path: str = Field(..., description="Path to the saved PyTorch model file (.ckpt for type='model') or the ensemble definition JSON file (.json for type='ensemble').")
model_config_path: str = Field(..., description="Path to the forecasting config (YAML) used for this model training (or for the best trial in an ensemble).")
target_scaler_path: Optional[str] = Field(None, description="Path to the target scaler file for the single model (or will be loaded per fold for ensemble).")
class OptimizationRunConfig(BaseModel):
"""Main configuration for running battery optimization with multiple models/ensembles."""
initial_b: float = Field(..., description="Initial state of charge of the battery (MWh).")
max_capacity: float = Field(..., description="Maximum energy capacity of the battery (MWh).")
max_rate: float = Field(..., description="Maximum charge/discharge power rate of the battery (MW).")
optimization_horizon_hours: int = Field(24, gt=0, description="The length of the time window (in hours) for optimization.")
models: List[ModelEvalConfig] = Field(..., description="List of forecasting models or ensembles to evaluate.")
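
# Illustrative instantiation (a sketch; the names and paths below are hypothetical, not project files):
# OptimizationRunConfig(
#     initial_b=0.0,
#     max_capacity=1.0,
#     max_rate=1.0,
#     optimization_horizon_hours=24,
#     models=[
#         ModelEvalConfig(
#             name="single_best",
#             type="model",
#             model_path="output/best_model.ckpt",
#             model_config_path="output/config.yaml",
#             target_scaler_path="output/target_scaler.pt",
#         ),
#     ],
# )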

603
optuna_ensemble_run.py Normal file
View File

@@ -0,0 +1,603 @@
import argparse
import logging
import sys
import warnings
import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml
import json # Import json to save best ensemble definition
import optuna
# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
# prepare_fold_data_and_loaders used by run_single_fold
)
# Import the single fold runner from the main script
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
from typing import List, Optional, Tuple, Dict, Any # Added Any for dictionary
# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds, aggregate_cv_metrics, save_results
# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning, # These specific warnings are often UserWarnings
module="pytorch_lightning.core.module"
)
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING) # Set PL to WARNING by default, INFO/DEBUG set below if needed
# --- Basic Logging Setup ---
# Configure logging early. Level will be set properly later based on config.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments for Optuna Ensemble HPO."""
parser = argparse.ArgumentParser(
description="Run HPO optimizing ensemble performance using Optuna.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config', type=str, default='forecasting_config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--output-dir', type=str, default=None,
help="Override base output directory for HPO results."
)
parser.add_argument(
'--keep-artifacts', action='store_true',
help="Prevent cleanup of trial directories after best trial is determined."
)
args = parser.parse_args()
return args
# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig,
df: pd.DataFrame,
hpo_base_output_dir: Path # Pass base dir for trial outputs
) -> float: # Return the single ensemble metric to optimize
"""
Optuna objective function optimizing ensemble performance.
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()
# Define trial-specific output directory for fold artifacts
trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")
hpo_config = base_config.optuna
# Metric for pruning based on individual fold performance
validation_metric_monitor = hpo_config.metric_to_optimize
# Ensemble metric and method to optimize (e.g., MAE of the 'mean' ensemble)
ensemble_metric_optimize = 'MAE'
ensemble_method_optimize = 'mean'
optimization_direction = hpo_config.direction # 'minimize' or 'maximize'
worst_value = float('inf') if optimization_direction == 'minimize' else float('-inf')
# Store paths and details for all saved artifacts for this trial's folds
fold_artifact_details: List[Dict[str, Any]] = [] # Changed to list of dicts
# --- 1. Suggest Hyperparameters ---
try:
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
except Exception as e:
logger.error(f"Trial {trial.number}: Failed to deep copy base config: {e}", exc_info=True)
return worst_value
# ----- Suggest Hyperparameters -----
# Modify trial_config_dict using trial.suggest_*
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Suggestions -----
# --- 2. Re-validate Trial Config ---
try:
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Disable plotting during HPO runs to save time/resources
trial_config_dict['evaluation']['save_plots'] = False
trial_config = MainConfig(**trial_config_dict)
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid config generated: {e}", exc_info=True)
return worst_value
# --- Early check for data length ---
# ... (Keep the check as in optuna_run.py) ...
try:
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
raise ValueError("Trial config has invalid forecast_horizon list.")
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
if min_data_for_sequence > len(df):
logger.warning(f"Trial {trial.number}: Skipped. sequence_length + max_horizon ({min_data_for_sequence}) exceeds data length ({len(df)}).")
# Report worst value so Optuna knows this trial failed badly
# Using study direction to determine the appropriate "worst" value
return worst_value
# raise optuna.TrialPruned() # Alternative: Prune instead of returning worst value
except Exception as e:
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
return worst_value
# --- 3. Run Cross-Validation Training (Saving Artifacts) ---
all_fold_best_val_scores = {} # Store best val scores for pruning
actual_folds_trained = 0
# Store paths to saved models and scalers for this trial
fold_model_paths = []
fold_scaler_paths = []
try:
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_id = fold_num + 1
logger.info(f"Trial {trial.number}, Fold {fold_id}/{cv_splitter.n_splits}: Training model...")
current_fold_best_metric = None # Reset for each fold
try:
# Use run_single_fold - it handles training and saving artifacts
# Pass trial_output_dir so fold artifacts are saved per trial
fold_metrics, best_val_score, saved_model_path, saved_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
config=trial_config, # Use the config with trial's hyperparameters
full_df=df,
output_base_dir=trial_artifacts_dir # Save folds under trial dir
)
actual_folds_trained += 1
all_fold_best_val_scores[fold_id] = best_val_score
# Store all artifact paths for this fold
fold_artifact_details.append({
"fold_id": fold_id,
"model_path": str(saved_model_path) if saved_model_path else None,
"target_scaler_path": str(saved_scaler_path) if saved_scaler_path else None,
"input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
"config_path": str(saved_config_path) if saved_config_path else None,
})
# Check if the monitored validation metric was returned
if best_val_score is not None and np.isfinite(best_val_score):
current_fold_best_metric = best_val_score
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best val score ({validation_metric_monitor}) = {current_fold_best_metric:.4f}")
else:
# Use worst value if metric is missing/invalid for pruning
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Invalid or missing validation score ({validation_metric_monitor}). Using {worst_value} for pruning.")
current_fold_best_metric = worst_value # Assign worst for pruning report
# Report intermediate value (individual fold validation score) for pruning
trial.report(current_fold_best_metric, fold_num)
if trial.should_prune():
logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
raise optuna.TrialPruned()
except optuna.TrialPruned:
raise # Propagate prune signal
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
all_fold_best_val_scores[fold_id] = None # Mark fold as failed
# Continue to next fold if possible, but report worst value for this fold
trial.report(worst_value, fold_num)
# Optionally raise prune here if too many folds fail? Or let the ensemble eval handle it.
except optuna.TrialPruned:
logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
return worst_value # Return worst value when pruned
except Exception as e:
logger.critical(f"Trial {trial.number}: Failed critically during CV training setup/loop: {e}", exc_info=True)
return worst_value
# --- 4. Run Ensemble Evaluation ---
if actual_folds_trained < 2:
logger.error(f"Trial {trial.number}: Only {actual_folds_trained} folds trained successfully. Cannot run ensemble evaluation.")
return worst_value # Not enough models for ensemble
logger.info(f"Trial {trial.number}: Starting Ensemble Evaluation using {actual_folds_trained} trained models...")
ensemble_metric_final = worst_value # Initialize to worst
    best_ensemble_method_for_trial = None  # Track the best method for this trial
    ensemble_results = None  # Ensure this is defined even if ensemble evaluation raises before assignment
try:
# Run evaluation using the artifacts saved in the trial's output directory
ensemble_results = run_ensemble_evaluation(
config=trial_config, # Pass trial config
output_base_dir=trial_artifacts_dir # Directory containing trial's fold subdirs
)
if ensemble_results:
# Aggregate the results to get the final objective value
ensemble_metrics_for_method = []
for fold_num, fold_res in ensemble_results.items():
if fold_res and ensemble_method_optimize in fold_res:
method_metrics = fold_res[ensemble_method_optimize]
if method_metrics and ensemble_metric_optimize in method_metrics:
metric_val = method_metrics[ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
ensemble_metrics_for_method.append(metric_val)
else:
logger.warning(f"Trial {trial.number}: Invalid ensemble metric value found for fold {fold_num}, method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
else:
logger.warning(f"Trial {trial.number}: Metric '{ensemble_metric_optimize}' not found for method '{ensemble_method_optimize}' in fold {fold_num}.")
else:
logger.warning(f"Trial {trial.number}: Ensemble method '{ensemble_method_optimize}' results not found for fold {fold_num}.")
if not ensemble_metrics_for_method:
logger.error(f"Trial {trial.number}: No valid ensemble metrics found for method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
ensemble_metric_final = worst_value
else:
# Calculate the mean of the chosen ensemble metric across test folds
ensemble_metric_final = np.mean(ensemble_metrics_for_method)
logger.info(f"Trial {trial.number}: Final Ensemble Metric (Avg {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
# Determine the best ensemble method based on average performance across folds
# This requires re-calculating averages for *all* methods evaluated by run_ensemble_evaluation
all_ensemble_methods = set()
for fold_res in ensemble_results.values():
if fold_res: all_ensemble_methods.update(fold_res.keys())
avg_metrics_per_method = {}
for method in all_ensemble_methods:
method_metrics_across_folds = []
for fold_res in ensemble_results.values():
if fold_res and method in fold_res and ensemble_metric_optimize in fold_res[method]:
metric_val = fold_res[method][ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
method_metrics_across_folds.append(metric_val)
if method_metrics_across_folds:
avg_metrics_per_method[method] = np.mean(method_metrics_across_folds)
if avg_metrics_per_method:
if optimization_direction == 'minimize':
best_ensemble_method_for_trial = min(avg_metrics_per_method, key=avg_metrics_per_method.get)
else: # maximize
best_ensemble_method_for_trial = max(avg_metrics_per_method, key=avg_metrics_per_method.get)
logger.info(f"Trial {trial.number}: Best performing ensemble method for this trial (based on avg {ensemble_metric_optimize}): {best_ensemble_method_for_trial}")
else:
logger.warning(f"Trial {trial.number}: Could not determine best ensemble method based on average {ensemble_metric_optimize}.")
else:
logger.error(f"Trial {trial.number}: Ensemble evaluation function returned no results.")
ensemble_metric_final = worst_value
except Exception as e:
logger.error(f"Trial {trial.number}: Failed during ensemble evaluation phase: {e}", exc_info=True)
ensemble_metric_final = worst_value
# --- 5. Return Final Objective Value ---
trial_duration = time.perf_counter() - trial_start_time
logger.info(f"--- Trial {trial.number}: Finished ---")
logger.info(f" Final Objective Value (Avg Ensemble {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
logger.info(f" Total time: {trial_duration:.2f}s")
# Store ensemble evaluation results and best method in trial user attributes
# This makes it easier to retrieve the best ensemble details after the study
trial.set_user_attr("ensemble_evaluation_results", ensemble_results)
trial.set_user_attr("best_ensemble_method", best_ensemble_method_for_trial)
# trial.set_user_attr("fold_model_paths", fold_model_paths) # Removed
# trial.set_user_attr("fold_scaler_paths", fold_scaler_paths) # Removed
trial.set_user_attr("fold_artifact_details", fold_artifact_details) # Added comprehensive artifact details
return ensemble_metric_final
# --- Main HPO Execution ---
def run_hpo():
"""Main execution function for HPO optimizing ensemble performance."""
args = parse_arguments()
config_path = Path(args.config)
try:
base_config = load_config(config_path)
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)
# --- Setup Output Dir ---
if args.output_dir:
hpo_base_output_dir = Path(args.output_dir)
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
else:
# Fallback to default if output_dir is not in config either
main_output_dir_str = getattr(base_config, 'output_dir', 'output')
if not main_output_dir_str: # Handle empty string case
main_output_dir_str = 'output'
main_output_dir = Path(main_output_dir_str)
hpo_base_output_dir = main_output_dir / f'{base_config.optuna.study_name}_ensemble_hpo' # Specific subdir using study name
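# Illustrative result: with output_dir 'output' and study name 'my_study', artifacts land
# under output/my_study_ensemble_hpo/ (names here are examples, not fixed paths).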
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
# --- Setup Logging ---
# Ensure the study name is filesystem-safe; compute it outside the try so later
# artifact paths (best params/config/ensemble files) still resolve if log setup fails.
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name)
try:
level_name = base_config.log_level.upper()
effective_log_level = logging.getLevelName(level_name)
log_file = hpo_base_output_dir / f"{safe_study_name}_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') # Specify encoding
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# Prevent adding duplicate handlers if script/function is called multiple times
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.setLevel(effective_log_level)
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
except (AttributeError, ValueError, TypeError) as e: # Added TypeError
logger.warning(f"Could not set log level from config: {e}. Defaulting to INFO.")
logger.setLevel(logging.INFO)
# Still try to log to a default file if possible
try:
log_file = hpo_base_output_dir / "default_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.info(f"Logging to default file: {log_file}")
except Exception as log_e:
logger.error(f"Failed to set up default file logging: {log_e}")
# --- Setup Seeding ---
set_seeds(getattr(base_config, 'random_seed', 42))
# --- Load Data ---
try:
logger.info("Loading base dataset for HPO...")
df = load_raw_data(base_config.data)
logger.info(f"Base dataset loaded. Shape: {df.shape}")
except FileNotFoundError as e:
logger.critical(f"Data file not found: {e}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
sys.exit(1)
# --- Optuna Study Setup ---
try:
hpo_config = base_config.optuna
if not hpo_config.enabled:
logger.info("Optuna optimization is disabled in the configuration.")
sys.exit(0)
except AttributeError:
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)
storage_path = hpo_config.storage
if storage_path and storage_path.startswith("sqlite:///"):
db_path_str = storage_path.replace("sqlite:///", "")
if not db_path_str:
# Default filename if only 'sqlite:///' is provided
db_path = hpo_base_output_dir / f"{base_config.optuna.study_name}.db"
logger.warning(f"SQLite path was empty, defaulting to: {db_path}")
else:
db_path = Path(db_path_str)
if not db_path.is_absolute():
db_path = hpo_base_output_dir / db_path
db_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent dir exists
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using SQLite storage: {storage_path}")
elif storage_path:
logger.info(f"Using Optuna storage: {storage_path} (Assuming non-SQLite or pre-configured)")
else:
storage_path = None # Explicitly set to None for in-memory
logger.warning("No Optuna storage DB specified, using in-memory storage.")
try:
# Single objective study based on ensemble performance
study = optuna.create_study(
study_name=hpo_config.study_name,
storage=storage_path,
direction=hpo_config.direction, # 'minimize' or 'maximize'
load_if_exists=True,
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
)
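# MedianPruner only acts on intermediate values reported via trial.report(); if the
# objective never reports, it behaves like NopPruner.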
# --- Run Optimization ---
logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df, hpo_base_output_dir), # Pass base_config and output dir
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True # Garbage collect after trial
)
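# gc_after_trial=True forces garbage collection between trials, helping to release the
# memory held by the per-fold PyTorch models and dataloaders.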
# --- Report and Save Best Trial ---
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")
# Filter trials to find the actual best one (excluding pruned/failed)
try:
best_trial = study.best_trial
except ValueError: # Optuna raises ValueError if no trials completed successfully
best_trial = None
logger.warning("No successful trials completed. Cannot determine best trial.")
if best_trial:
logger.info("--- Best Trial ---")
logger.info(f" Trial Number: {best_trial.number}")
# Ensure value is not None before formatting
best_value_str = f"{best_trial.value:.6f}" if best_trial.value is not None else "N/A"
logger.info(f" Objective Value (Ensemble Metric): {best_value_str}")
logger.info(f" Hyperparameters:")
best_params = best_trial.params
for key, value in best_params.items():
logger.info(f" {key}: {value}")
# Save best hyperparameters
best_params_file = hpo_base_output_dir / f"{safe_study_name}_best_params.json"
try:
import json # Imported once here; also used below for the ensemble definition
with open(best_params_file, 'w', encoding='utf-8') as f:
json.dump(best_params, f, indent=4)
logger.info(f"Best hyperparameters saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save best parameters: {e}", exc_info=True)
# Save the corresponding config
best_config_file = hpo_base_output_dir / f"{safe_study_name}_best_config.yaml"
try:
# Use a fresh deepcopy to avoid modifying the original base_config
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
# Update with best trial's hyperparameters
# Pitfall: This assumes keys match exactly and exist in these sections.
# A more robust approach might involve checking key existence or
# iterating through the config structure if params are nested differently.
for key, value in best_params.items():
if key in best_config_dict.get('training', {}): best_config_dict['training'][key] = value
elif key in best_config_dict.get('model', {}): best_config_dict['model'][key] = value
elif key in best_config_dict.get('features', {}): best_config_dict['features'][key] = value
else:
logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")
# Ensure forecast horizon is preserved from the original config
best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Maybe remove optuna section from the best config?
# best_config_dict.pop('optuna', None)
with open(best_config_file, 'w', encoding='utf-8') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
logger.info(f"Configuration for best trial saved to {best_config_file}")
except Exception as e:
logger.error(f"Failed to save best configuration: {e}", exc_info=True)
# Retrieve saved artifact paths and best ensemble method from user attributes
best_trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
# fold_model_paths = best_trial.user_attrs.get("fold_model_paths", []) # Removed
# fold_scaler_paths = best_trial.user_attrs.get("fold_scaler_paths", []) # Removed
fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", []) # Retrieve comprehensive details
if not best_trial_artifacts_dir.exists():
logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
elif not best_ensemble_method:
logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
elif not fold_artifact_details: # Check if any artifact details were recorded
logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
else:
# --- Save Best Ensemble Definition ---
logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")
ensemble_definition_file = hpo_base_output_dir / f"{safe_study_name}_best_ensemble.json"
best_ensemble_definition = {
"trial_number": best_trial.number,
"objective_value": best_trial.value,
"hyperparameters": best_trial.params,
"ensemble_method": best_ensemble_method,
"fold_models": [], # List of dictionaries for each fold's model and scaler, input_size, config
"ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(hpo_base_output_dir)) # Save path relative to hpo_base_output_dir
}
# Populate fold_models with paths to saved artifacts
for artifact_detail in fold_artifact_details:
fold_def = {
"fold_id": artifact_detail.get("fold_id"), # Include fold ID
"model_path": None,
"target_scaler_path": None,
"input_size_path": None,
"config_path": None,
}
# Process each path, making it relative if possible
for key in ["model_path", "target_scaler_path", "input_size_path", "config_path"]:
abs_path_str = artifact_detail.get(key)
if abs_path_str:
abs_path = Path(abs_path_str)
try:
# Make path relative to the trial artifacts dir
relative_path = str(abs_path.relative_to(best_trial_artifacts_dir))
fold_def[key] = relative_path
except ValueError:
logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
fold_def[key] = str(abs_path) # Fallback to absolute path
best_ensemble_definition["fold_models"].append(fold_def)
try:
with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
json.dump(best_ensemble_definition, f, indent=4)
logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
except Exception as e:
logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)
# --- Optional: Clean up artifact directories for non-best trials ---
if not args.keep_artifacts:
logger.info("Cleaning up artifact directories for non-best trials...")
ensemble_artifacts_base_dir = hpo_base_output_dir / "ensemble_runs_artifacts"
if ensemble_artifacts_base_dir.exists():
for item in ensemble_artifacts_base_dir.iterdir():
if item.is_dir():
# Check if this directory belongs to the best trial
if best_trial and item.name == f"trial_{best_trial.number}":
logger.debug(f"Keeping artifact directory for best trial: {item}")
continue
else:
logger.debug(f"Removing artifact directory for non-best trial: {item}")
try:
import shutil
shutil.rmtree(item)
except Exception as e:
logger.error(f"Failed to remove directory {item}: {e}", exc_info=True)
else:
logger.debug(f"Artifacts base directory not found for cleanup: {ensemble_artifacts_base_dir}")
except optuna.exceptions.StorageInternalError as e:
logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
run_hpo() # Changed main() to run_hpo()

View File

@ -1,35 +1,44 @@
import argparse
import logging
import sys
import warnings # Import the warnings module
import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml # Added for saving best config
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# Import the Optuna callback for pruning
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
from pytorch_lightning.callbacks import EarlyStopping
# Import necessary components from the forecasting_model package
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
prepare_fold_data_and_loaders,
split_data_classic
)
from forecasting_model.model import LSTMForecastLightningModule
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
# from forecasting_model.evaluation import ...
from typing import Dict, List, Any, Optional
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.train.classic import run_classic_training
# Import helper functions from forecasting_model.py (or move them to a shared utils file)
# For now, let's redefine simplified versions or assume they exist in utils
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
# Import helper functions from forecasting_model_run.py
from forecasting_model.utils.helper import load_config, set_seeds
# Import the data processing functions
from forecasting_model.data_processing import load_raw_data
# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning,
module="pytorch_lightning.core.module"
)
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
@ -37,358 +46,396 @@ mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
pl_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
root_logger = logging.getLogger()
logger = logging.getLogger(__name__) # Logger for this script
optuna_lg = logging.getLogger('optuna') # Optuna's logger
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
# --- Argument Parsing (Simplified) ---
def parse_arguments():
"""Parses command-line arguments for Optuna HPO."""
parser = argparse.ArgumentParser(
description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
description="Run Hyperparameter Optimization using Optuna.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the BASE YAML configuration file."
'-c', '--config', type=str, default='forecasting_config.yaml',
help="Path to the YAML configuration file containing HPO settings."
)
parser.add_argument(
'--output-dir',
type=str,
default='output/hpo_results',
help="Directory for saving Optuna study database and potentially best trial info."
'--output-dir', type=str, default=None,
help="Override output directory specified in the configuration file."
)
parser.add_argument(
'--study-name',
type=str,
default='lstm_forecasting_hpo',
help="Name for the Optuna study."
)
parser.add_argument(
'--n-trials',
type=int,
default=20,
help="Number of Optuna trials to run."
)
parser.add_argument(
'--storage-db',
type=str,
default=None, # Default to in-memory if not specified
help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage."
)
parser.add_argument(
'--metric-to-optimize',
type=str,
default='val_mae_orig_scale',
help="Metric logged during validation to optimize (must match metric name in LightningModule)."
)
parser.add_argument(
'--direction',
type=str,
default='minimize',
choices=['minimize', 'maximize'],
help="Direction for Optuna optimization."
)
parser.add_argument(
'--pruning',
action='store_true',
help="Enable Optuna's trial pruning based on intermediate validation results."
)
parser.add_argument(
'--seed',
type=int,
default=42, # Fixed seed for the HPO process itself
help="Random seed for the main HPO script."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)
args = parser.parse_args()
return args
# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig, # Pass the loaded base config
df: pd.DataFrame, # Pass the loaded data
output_base_dir: Path, # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
metric_to_optimize: str,
enable_pruning: bool
) -> float:
base_config: MainConfig,
df: pd.DataFrame,
) -> float: # Ensure it returns a float
"""
Optuna objective function. Trains and evaluates one set of hyperparameters
using cross-validation and returns the average validation metric.
Optuna single-objective function using a classic train/val/test split.
Returns:
- Validation score from the classic split (minimize).
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()
# Get HPO settings
hpo_config = base_config.optuna
# Use the specific metric name from config for validation monitoring
validation_metric_monitor = hpo_config.metric_to_optimize
# The objective will be based on the validation metric
# Hardcode minimization for the single objective
monitor_mode = "min"
worst_value = float('inf') # Represents a poor validation score
# --- 1. Suggest Hyperparameters ---
# Make a deep copy of the base config to modify for this trial
# Using dict conversion and back might be easier than Pydantic's copy for deep nested updates
try:
trial_config_dict = copy.deepcopy(base_config.dict()) # Convert to dict for easier modification
# Deep copy the base config dictionary to modify for the trial
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python')) # Use mode='python' for easier modification
except Exception as e:
logger.error(f"Failed to deep copy base configuration: {e}")
raise # Cannot proceed without config
logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
# Return worst value
return worst_value
# Suggest values for hyperparameters we want to tune
# Example suggestions (adjust ranges and types as needed):
# ----- Suggest Hyperparameters -----
# (Suggest parameters as before, modifying trial_config_dict)
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
# Example: Suggest sequence length? (Requires careful handling as it affects data prep)
# trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
# Note: forecast_horizon is NOT tuned here, taken from base_config
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Hyperparameter Suggestions -----
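# The boolean flags above switch whole feature groups on or off instead of tuning their
# values, keeping the search space small while still letting Optuna decide whether each
# group is worth including.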
# --- 2. Re-validate Trial Config (Optional but Recommended) ---
# --- 2. Re-validate Trial Config ---
try:
# Transfer forecast horizon from base config (not tuned)
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
trial_config = MainConfig(**trial_config_dict)
logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
# Return a high value (for minimization) to penalize invalid configs
return float('inf')
# --- 3. Run Cross-Validation for this Trial ---
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
fold_best_val_metrics: List[Optional[float]] = []
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_id = fold_num + 1
logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
fold_start_time = time.perf_counter()
# Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
# fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
# fold_trial_dir.mkdir(parents=True, exist_ok=True)
logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
return worst_value
# --- Early check for invalid sequence length / forecast horizon combination ---
try:
# --- Per-Fold Data Prep ---
# Use trial_config for batch sizes etc.
train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, # Test loader not needed here
target_col=trial_config.data.target_col,
feature_config=trial_config.features,
train_config=trial_config.training,
eval_config=trial_config.evaluation # Pass eval for batch size if needed by prep?
# Make sure forecast_horizon is a list before checking max()
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
raise ValueError("Trial config has invalid forecast_horizon list.")
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
if min_data_for_sequence > len(df):
logger.warning(f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
f"exceeds data length ({len(df)}).")
# Optuna doesn't directly support skipping a trial, so return the worst value
return worst_value
except Exception as e:
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
return worst_value
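# Returning worst_value records the trial as COMPLETE with a very poor score (instead of
# FAILED), so the sampler still learns that this parameter combination is infeasible.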
# --- 3. Run Classic Train/Test ---
logger.info(f"Trial {trial.number}: Starting Classic Run...")
validation_metric_value = worst_value # Initialize to worst
try:
n_samples = len(df)
val_frac = trial_config.cross_validation.val_size_fraction
test_frac = trial_config.cross_validation.test_size_fraction
train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)
# Prepare data for classic split
train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, input_size_cl = prepare_fold_data_and_loaders(
full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
target_col=trial_config.data.target_col, feature_config=trial_config.features,
train_config=trial_config.training, eval_config=trial_config.evaluation
)
# --- Model Instantiation ---
current_model_config = trial_config.model.copy(update={'input_size': input_size,
'forecast_horizon': trial_config.features.forecast_horizon})
model = LSTMForecastLightningModule(
model_config=current_model_config,
train_config=trial_config.training,
target_scaler=target_scaler
# Initialize Model
model_cl = LSTMForecastLightningModule(
model_config=trial_config.model, train_config=trial_config.training,
input_size=input_size_cl, target_scaler=target_scaler_cl
)
# --- Callbacks for this Trial/Fold ---
# Monitor the metric Optuna cares about
monitor_mode = "min" if args.direction == "minimize" else "max"
callbacks = []
if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
early_stopping = EarlyStopping(
monitor=metric_to_optimize,
# Callbacks (EarlyStopping and Pruning)
callbacks_cl = []
if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
callbacks_cl.append(EarlyStopping(
monitor=validation_metric_monitor, # Monitor validation metric
patience=trial_config.training.early_stopping_patience,
mode=monitor_mode,
verbose=False # Less verbose during HPO
)
callbacks.append(early_stopping)
mode=monitor_mode, verbose=False
))
# Add Optuna Pruning Callback
if enable_pruning:
pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
callbacks.append(pruning_callback)
# Optional: LR Monitor
# callbacks.append(LearningRateMonitor(logging_interval='epoch'))
# --- Trainer for this Trial/Fold ---
trainer = pl.Trainer(
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
devices=1 if torch.cuda.is_available() else None,
max_epochs=trial_config.training.epochs,
callbacks=callbacks,
logger=False, # Disable default PL logging during HPO
enable_checkpointing=False, # Disable checkpoint saving during HPO
enable_progress_bar=False, # Disable progress bar for cleaner logs
enable_model_summary=False, # Disable model summary
# Trainer for classic run
trainer_cl = pl.Trainer(
accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1 if torch.cuda.is_available() else None,
max_epochs=trial_config.training.epochs, callbacks=callbacks_cl, logger=False, # logger=False as per original
enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
precision=getattr(trial_config.training, 'precision', 32),
# Log GPU usage if available?
# log_gpu_memory='min_max',
)
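# Logger, checkpointing, progress bar and model summary are all disabled to keep each HPO
# trial lightweight; only the monitored validation metric is needed from this run.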
# --- Fit the Model ---
logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# --- Train Model ---
logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)
# --- Get Best Validation Score for Pruning/Reporting ---
# Access the monitored metric value from the trainer's logged metrics or callback state
# Ensure the key matches exactly what's logged in validation_step
best_val_score = trainer.callback_metrics.get(metric_to_optimize)
# --- Get Best Validation Score ---
# Check early stopping callback first if it exists
best_score_tensor = None
if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
best_score_tensor = callbacks_cl[0].best_score
elif callbacks_cl[0].stopped_epoch > 0: # Early stopping triggered
logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")
if best_val_score is None:
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using inf/nan.")
# Handle cases where training might have failed or metric wasn't logged
best_val_score = float('inf') if monitor_mode == 'min' else float('-inf') # Return worst possible value
# If early stopping didn't capture best score, use last metrics from trainer
if best_score_tensor is None:
metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
if metric_val is not None:
best_score_tensor = metric_val # Use the last logged value
if best_score_tensor is None:
logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
validation_metric_value = worst_value
else:
best_val_score = best_val_score.item() # Convert tensor to float
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")
validation_metric_value = best_score_tensor.item()
logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")
fold_best_val_metrics.append(best_val_score)
# Report intermediate value for pruning (if enabled)
trial.report(validation_metric_value, trainer_cl.current_epoch)
if trial.should_prune():
logger.info(f"Trial {trial.number}: Pruned.")
raise optuna.TrialPruned()
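# trial.report() is only called once here (after fit), so MedianPruner can only compare
# trials at that single point; per-epoch pruning would need a callback such as
# PyTorchLightningPruningCallback attached to the Trainer.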
# --- Intermediate Pruning Report (Optional but Recommended) ---
# Report the intermediate value (best score for this fold) to Optuna
# trial.report(best_val_score, fold_id) # Report score at step `fold_id`
# Check if the trial should be pruned based on reported values
# if trial.should_prune():
# logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
# raise optuna.TrialPruned()
# Note: We don't run prediction/evaluation on the test set here,
# as the objective is based on validation performance.
# The test set evaluation will be done later for the best trial.
logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")
except optuna.TrialPruned:
# Re-raise prune exception to let Optuna handle it
# Propagate the prune signal; Optuna records this trial as PRUNED
raise
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
# Record a failure for this fold (e.g., append NaN or worst value)
fold_best_val_metrics.append(float('inf') if monitor_mode == 'min' else float('-inf'))
# Optionally: Break the CV loop for this trial if one fold fails catastrophically?
# break
logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)
validation_metric_value = worst_value # Assign worst value if classic run fails
finally:
# Clean up GPU memory after the run; guard against NameError when data preparation
# or model construction failed before these names were bound.
try:
del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl
except NameError:
pass
if torch.cuda.is_available(): torch.cuda.empty_cache()
# --- 4. Calculate Average Metric Across Folds ---
if not fold_best_val_metrics:
logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
return float('inf') # Return worst value
# Handle potential infinities or NaNs from failed folds
valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
if not valid_scores:
logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
return float('inf')
average_val_metric = np.mean(valid_scores)
# --- 4. Return Objective ---
logger.info(f"--- Trial {trial.number}: Finished ---")
logger.info(f" Average validation {metric_to_optimize}: {average_val_metric:.5f}")
logger.info(f" Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
logger.info(f" Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}")
logger.info(f" Total time: {time.perf_counter() - trial_start_time:.2f}s")
# --- 5. Return Metric for Optuna ---
return average_val_metric
# Return the single objective (validation metric)
return float(validation_metric_value)
# --- Main HPO Execution ---
def run_hpo():
"""Main execution function for HPO."""
global args # Make args accessible in objective (simplifies passing) - or use functools.partial
args = parse_arguments()
config_path = Path(args.config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) # Ensure output dir exists
# Adjust log level if debug flag is set
if args.debug:
root_logger.setLevel(logging.DEBUG)
optuna_lg.setLevel(logging.DEBUG)
pl_logger.setLevel(logging.DEBUG)
logger.debug("Debug mode enabled.")
else:
# Reduce verbosity during HPO runs
optuna_lg.setLevel(logging.WARNING)
pl_logger.setLevel(logging.INFO) # Keep INFO for PL start/end messages
# --- Configuration Loading ---
try:
base_config = load_config(config_path)
except Exception:
base_config = load_config(config_path) # Load base config once
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)
# --- Seed Setting (for HPO script itself) ---
set_seeds(args.seed)
# Setup output dir...
if args.output_dir:
hpo_base_output_dir = Path(args.output_dir) # Use specific name for HPO dir
logger.info(f"Using HPO output directory from command line: {hpo_base_output_dir}")
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
logger.info(f"Using HPO output directory from Optuna storage path: {hpo_base_output_dir}")
else:
# Use output_dir from main config if available, otherwise default
main_output_dir = Path(getattr(base_config, 'output_dir', 'output'))
hpo_base_output_dir = main_output_dir / 'hpo_results'
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
# --- Load Data Once ---
# Assume data doesn't change based on HPs (unless sequence_length is tuned heavily)
# Setup logging... (ensure file handler uses hpo_base_output_dir)
try:
logger.info("Loading base dataset...")
level_name = base_config.log_level.upper()
effective_log_level = logging.getLevelName(level_name)
log_file = hpo_base_output_dir / f"{base_config.optuna.study_name}_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# Add handler only if it's not already added (e.g., if run_hpo is called multiple times)
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.setLevel(effective_log_level)
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
except (AttributeError, ValueError) as e:
logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
logger.setLevel(logging.INFO)
# Setup seeding...
try:
seed = base_config.random_seed
set_seeds(seed)
except AttributeError:
logger.warning("Config missing 'random_seed'. Using default seed 42.")
set_seeds(42)
# Load data...
try:
logger.info("Loading base dataset for HPO...")
df = load_raw_data(base_config.data)
logger.info("Base dataset loaded.")
logger.info(f"Base dataset loaded. Shape: {df.shape}")
except Exception as e:
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
sys.exit(1)
# --- Optuna Study Setup ---
storage_path = args.storage_db
if storage_path:
# Ensure directory exists if using SQLite file storage
db_path = Path(storage_path.replace("sqlite:///", ""))
db_path.parent.mkdir(parents=True, exist_ok=True)
storage_path = f"sqlite:///{db_path.resolve()}" # Use absolute path
logger.info(f"Using Optuna storage: {storage_path}")
else:
logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")
try:
hpo_config = base_config.optuna
if not hpo_config.enabled:
logger.info("Optuna optimization is disabled in the configuration.")
sys.exit(0)
except AttributeError:
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)
storage_path = hpo_config.storage
if storage_path and storage_path.startswith("sqlite:///"):
db_path = Path(storage_path.replace("sqlite:///", ""))
if not db_path.is_absolute():
db_path = hpo_base_output_dir / db_path # Relative to HPO output dir
db_path.parent.mkdir(parents=True, exist_ok=True)
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using Optuna storage: {storage_path}")
elif storage_path:
logger.info(f"Using Optuna storage: {storage_path} (non-SQLite)")
else:
logger.warning("No Optuna storage DB specified in config, using in-memory storage.")
try:
# Create or load the study
# Change to single objective 'minimize'
study = optuna.create_study(
study_name=args.study_name,
study_name=hpo_config.study_name,
storage=storage_path,
direction=args.direction,
load_if_exists=True, # Load previous results if study exists
pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() # Example pruner
direction="minimize", # Changed to single direction
load_if_exists=True,
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
)
# Remove multi-objective check/attribute setting
# if not study._is_multi_objective:
# logger.warning(f"Study '{hpo_config.study_name}' exists but is not multi-objective.")
# elif 'objective_names' not in study.user_attrs:
# study.set_user_attr('objective_names', objective_names)
# --- Run Optimization ---
logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}") # Updated log message
study.optimize(
lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
n_trials=args.n_trials,
timeout=None # Optional: Set timeout in seconds
# Optional: Add callbacks (e.g., logging callback)
lambda trial: objective(trial, base_config, df), # Pass base_config
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True
)
# --- Report Best Trial ---
# --- Report and Process Best Trial ---
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")
# Get the single best trial (study.best_trial raises ValueError if no trial completed successfully)
try:
best_trial = study.best_trial
except ValueError:
best_trial = None
logger.info(f"Best trial number: {best_trial.number}")
logger.info(f" Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
logger.info(" Best hyperparameters:")
if best_trial is None:
logger.warning("Optuna study finished, but no successful trial was completed.")
else:
logger.info(f"Best trial found (Trial {best_trial.number}):")
# Log details for the best trial
validation_metric_monitor = base_config.optuna.metric_to_optimize
logger.info(f" Objective ({validation_metric_monitor}): {best_trial.value:.5f}") # Use .value for single objective
logger.info(f" Hyperparameters:")
for key, value in best_trial.params.items():
logger.info(f" {key}: {value}")
# --- Save Best Hyperparameters (Optional) ---
best_params_file = output_dir / f"{args.study_name}_best_params.json"
# --- Re-run and Save Artifacts for the Best Trial ---
logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
trial_output_dir = hpo_base_output_dir / f"best_trial_num{best_trial.number}" # Simplified directory name
trial_output_dir.mkdir(parents=True, exist_ok=True)
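# Re-training is necessary because checkpointing was disabled during the HPO trials
# (enable_checkpointing=False above), so no model artifact exists yet for the best trial.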
try:
# 1. Create config for this trial
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
# Update with trial's hyperparameters
for key, value in best_trial.params.items(): # Use best_trial.params
# Need to place the param in the correct nested config dict
if key in best_config_dict['training']: best_config_dict['training'][key] = value
elif key in best_config_dict['model']: best_config_dict['model'][key] = value
elif key in best_config_dict['features']: best_config_dict['features'][key] = value
# Add more sections if HPO tunes them
# Add non-tuned features forecast_horizon back
best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Ensure evaluation plots and model saving are enabled for this final run
best_config_dict['evaluation']['save_plots'] = True
best_config_dict['training']['save_model'] = True # Assuming you want to save the best model
# Validate the final config for this trial
best_trial_config = MainConfig(**best_config_dict)
# Save the specific config used for this run
with open(trial_output_dir / "best_config.yaml", 'w') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False)
# 2. Run classic training (which saves model & plots)
logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
# Pass the specific config and output directory
run_classic_training(
config=best_trial_config,
full_df=df,
output_base_dir=trial_output_dir # outputs -> hpo_results/best_trial_num<n>/classic_run/
)
logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {trial_output_dir}")
except Exception as e:
logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)
# --- Save Best Hyperparameters ---
best_params_file = hpo_base_output_dir / f"{hpo_config.study_name}_best_params.json" # Simplified filename
try:
with open(best_params_file, 'w') as f:
import json
json.dump(best_trial.params, f, indent=4)
logger.info(f"Best hyperparameters saved to {best_params_file}")
json.dump(best_trial.params, f, indent=4) # Use best_trial.params
logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save best parameters: {e}")
logger.error(f"Failed to save parameters for best trial: {e}")
except Exception as e:
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":