intermediate backup
@@ -4,7 +4,7 @@ import pandas as pd
 import json
 from typing import Optional, Dict, List, Any
 # Use utils for config if that's the structure
-from data_analysis.utils.config_model import settings
+from data_analysis.utils.data_config_model import settings
 import datetime

 logger = logging.getLogger(__name__)
@@ -3,7 +3,7 @@ from pathlib import Path
 import pandas as pd
 from typing import Tuple, Optional, Dict, Any

-from data_analysis.utils.config_model import settings
+from data_analysis.utils.data_config_model import settings

 logger = logging.getLogger(__name__)

@@ -9,7 +9,7 @@ import shutil

 import pandas as pd

-from data_analysis.utils.config_model import settings # Assuming settings are configured
+from data_analysis.utils.data_config_model import settings # Assuming settings are configured
 from data_analysis.utils.report_model import ReportData

 logger = logging.getLogger(__name__)
@@ -5,7 +5,7 @@ from pathlib import Path
 import time

 # Import necessary components from your project structure
-from data_analysis.utils.config_model import load_settings, Settings # Import loading function and model
+from data_analysis.utils.data_config_model import load_settings, Settings # Import loading function and model
 from data_analysis.analysis.pipeline import run_eda_pipeline # Import the pipeline entry point

 # Silence overly verbose libraries if needed (e.g., matplotlib)
@@ -2,6 +2,13 @@

 project_name: "TimeSeriesForecasting" # Name for the project/run
 random_seed: 42 # Optional: Global random seed for reproducibility
+log_level: INFO # Or DEBUG
+
+# --- Execution Control ---
+run_cross_validation: true # Run the main cross-validation loop?
+run_classic_training: true # Run a single classic train/val/test split?
+run_ensemble_evaluation: true # Run ensemble evaluation (requires run_cross_validation=true)?
+# --- End Execution Control ---

 # --- Data Loading Configuration ---
 data:
@@ -20,7 +27,7 @@ data:
 # --- Feature Engineering & Preprocessing Configuration ---
 features:
   sequence_length: 72 # REQUIRED: Lookback window size (e.g., 72 hours = 3 days)
-  forecast_horizon: 24 # REQUIRED: Number of steps ahead to predict (e.g., 24 hours)
+  forecast_horizon: [ 1, 6, 12, 24] # REQUIRED: List of steps ahead to predict (e.g., 1 hour, 6 hours, 12 hours, 24 hours, 48 hours, 72 hours, 168 hours)
   lags: [24, 48, 72, 168] # List of lag features to create (e.g., 1 day, 2 days, 3 days, 1 week)
   rolling_window_sizes: [24, 72, 168] # List of window sizes for rolling stats (mean, std)
   use_time_features: true # Create calendar features (hour, dayofweek, month, etc.)?
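Note: `forecast_horizon` switches here from a single integer to a list, so the config model that parses this section has to accept and validate a list. The project's actual `FeatureConfig` (in `forecast_config_model`, referenced later in this diff) is not shown, so the following stand-in dataclass is only a sketch of the validation the list-valued field needs; all names besides the YAML keys are assumptions.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FeatureConfigSketch:
    """Hypothetical stand-in mirroring the YAML keys under `features:`."""
    sequence_length: int = 72
    forecast_horizon: List[int] = field(default_factory=lambda: [1, 6, 12, 24])
    lags: List[int] = field(default_factory=lambda: [24, 48, 72, 168])
    rolling_window_sizes: List[int] = field(default_factory=lambda: [24, 72, 168])
    use_time_features: bool = True

    def __post_init__(self) -> None:
        # Same constraints the TimeSeriesDataset enforces further down in this diff.
        if self.sequence_length <= 0:
            raise ValueError("sequence_length must be positive")
        if not self.forecast_horizon or any(h <= 0 for h in self.forecast_horizon):
            raise ValueError("forecast_horizon must be a non-empty list of positive integers")
```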
@@ -62,7 +69,7 @@ training:
   scheduler_step_size: null # Optional: Step size for StepLR scheduler (epochs). Set null/None to disable. Must be > 0 if set.
   scheduler_gamma: null # Optional: Gamma factor for StepLR scheduler. Set null/None to disable. Must be 0 < gamma < 1 if set.
   gradient_clip_val: 1.0 # Optional: Value for gradient clipping. Set null/None to disable. Must be >= 0.0 if set.
-  num_workers: 0 # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
+  num_workers: 4 # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
   precision: 32 # Training precision (16, 32, 64, 'bf16')

 # --- Cross-Validation Configuration (Rolling Window) ---
@@ -80,9 +87,11 @@ evaluation:

 # --- Optuna Hyperparameter Optimization Configuration ---
 optuna:
-  enabled: false # Enable Optuna HPO? If true, requires optuna.py script.
-  n_trials: 20 # Number of trials to run (must be > 0)
-  storage: null # Optional: Optuna storage URL (e.g., "sqlite:///output/hpo_results/study.db"). If null, uses in-memory.
-  direction: "minimize" # Optimization direction ('minimize' or 'maximize')
-  metric_to_optimize: "val_mae_orig_scale" # Metric logged by LightningModule to optimize
-  pruning: true # Enable Optuna trial pruning?
+  enabled: true # Set to true to actually run HPO via optuna_run.py
+  study_name: "lstm_price_forecast_hpo_v1" # Specific name for this study
+  n_trials: 200 # Number of trials to run
+  storage: "sqlite:///output/hpo_results/study_v1.db" # Path to database file
+  direction: "minimize" # 'minimize' or 'maximize'
+  metric_to_optimize: "val_MeanAbsoluteError" # Metric logged in validation_step
+  pruning: true # Enable pruning
+
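Note: the new `study_name`/`storage`/`direction`/`pruning` fields map directly onto Optuna's study-creation API, and `metric_to_optimize` must match a metric actually logged in `validation_step`. The real wiring lives in `optuna_run.py`, which is not part of this diff, so the objective below is a trivial placeholder and the whole block is only a sketch of how these config values would be consumed.

```python
from pathlib import Path

import optuna


def objective(trial: optuna.Trial) -> float:
    # Placeholder objective: sample one hyperparameter and return a dummy
    # validation metric so the sketch runs end to end.
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    return abs(lr - 1e-3)


# The sqlite URL from the config needs its parent directory to exist.
Path("output/hpo_results").mkdir(parents=True, exist_ok=True)

study = optuna.create_study(
    study_name="lstm_price_forecast_hpo_v1",
    storage="sqlite:///output/hpo_results/study_v1.db",
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(),  # corresponds to pruning: true
    load_if_exists=True,
)
study.optimize(objective, n_trials=200)
print(study.best_value, study.best_params)
```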
@@ -15,7 +15,7 @@ from .data_processing import (
     prepare_fold_data_and_loaders,
     TimeSeriesDataset
 )
-from .model import LSTMForecastLightningModule
+from forecasting_model.train.model import LSTMForecastLightningModule
 from .evaluation import (
     evaluate_fold_predictions,
     # Optionally expose the standalone evaluation utility if needed externally
@@ -5,9 +5,10 @@ import torch
 from torch.utils.data import Dataset, DataLoader
 from sklearn.preprocessing import StandardScaler, MinMaxScaler
 from typing import Tuple, Generator, List, Optional, Union, Dict, Literal, Type
+import math # Add math import

 # Use relative import for utils within the package
-from .utils.config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
+from .utils.forecast_config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
 # Optional: Import wavelet library if needed later
 # import pywt

@@ -264,31 +265,39 @@ def engineer_features(df: pd.DataFrame, target_col: str, feature_config: Feature
         if isinstance(nan_handler, str):
             if nan_handler in ['ffill', 'bfill']:
                 fill_method = nan_handler
-                logger.debug(f"Filling NaNs in generated features using method: '{fill_method}'")
+                logger.debug(f"Selected NaN fill method for generated features: '{fill_method}'")
             elif nan_handler == 'mean':
                 logger.warning("NaN filling with 'mean' in generated features is applied globally here;"
                                " consider per-fold mean filling if lookahead is a concern.")
-                # Calculate mean only on the slice provided, potentially leaking info if slice includes val/test
-                # Better to use ffill/bfill here or handle after split
-                fill_value = features_df[feature_cols_generated].mean() # Calculate mean per feature column
-                logger.debug("Filling NaNs in generated features using column means.")
+                fill_value = features_df[feature_cols_generated].mean()
+                logger.debug("Selected NaN fill method: column means.")
             else:
                 logger.warning(f"Unsupported string fill_nan method '{nan_handler}' for generated features. Using 'ffill'.")
-                fill_method = 'ffill'
+                fill_method = 'ffill' # Default to ffill if unsupported string
         elif isinstance(nan_handler, (int, float)):
             fill_value = float(nan_handler)
-            logger.debug(f"Filling NaNs in generated features with value: {fill_value}")
+            logger.debug(f"Selected NaN fill value for generated features: {fill_value}")
         else:
             logger.warning(f"Invalid fill_nan type: {type(nan_handler)}. NaNs in features may remain.")

-        # Apply filling only to generated feature columns
-        if fill_method:
-            features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method=fill_method)
-            if fill_method == 'ffill':
-                features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method='bfill')
+        # Apply filling only to generated feature columns using recommended methods
+        if fill_method == 'ffill':
+            logger.debug("Applying .ffill() to generated features...")
+            features_df[feature_cols_generated] = features_df[feature_cols_generated].ffill()
+            # Apply bfill afterwards to handle any NaNs remaining at the very beginning
+            logger.debug("Applying .bfill() to handle any remaining NaNs at the start...")
+            features_df[feature_cols_generated] = features_df[feature_cols_generated].bfill()
+        elif fill_method == 'bfill':
+            logger.debug("Applying .bfill() to generated features...")
+            features_df[feature_cols_generated] = features_df[feature_cols_generated].bfill()
+            # Optionally apply ffill after bfill if you need to fill trailing NaNs (less common)
+            # features_df[feature_cols_generated] = features_df[feature_cols_generated].ffill()
         elif fill_value is not None:
             # fillna with Series/dict for column-wise mean, or scalar for constant value
+            logger.debug(f"Applying .fillna(value={fill_value}) to generated features...")
             features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(value=fill_value)
+        # No else needed, if fill_method and fill_value are None, no filling happens

     else:
         logger.warning("`fill_nan` is None. NaNs generated by feature engineering may remain.")

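Note: the switch from `fillna(method=...)` to chained `.ffill()`/`.bfill()` matches current pandas guidance (the `method=` argument of `fillna` is deprecated in recent pandas). Forward-filling first closes gaps created by lag and rolling-window features, and the subsequent back-fill handles the NaNs that necessarily remain at the very start of the slice, where no earlier value exists. A small self-contained illustration of the pattern (toy data, not from the project):

```python
import pandas as pd

df = pd.DataFrame({"price": [10.0, 11.0, 12.0, 13.0]})
# Lag/rolling features introduce NaNs at the head of the frame.
df["price_lag_2"] = df["price"].shift(2)
df["price_roll_mean_2"] = df["price"].rolling(2).mean()

generated_cols = ["price_lag_2", "price_roll_mean_2"]
# ffill fills gaps from earlier values; bfill then covers leading NaNs
# for which no earlier value exists.
df[generated_cols] = df[generated_cols].ffill()
df[generated_cols] = df[generated_cols].bfill()
print(df)
```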
@@ -366,36 +375,31 @@ class TimeSeriesCrossValidationSplitter:

         # Estimate if None
         elif self.initial_train_size is None:
-            min_samples_per_split_step = 2 # Heuristic minimum samples for val+test in one step
-            # Estimate val/test based on *potential* train size (crude)
-            # Assume train is roughly (1 - val - test) fraction for estimation
-            estimated_train_frac = max(0.1, 1.0 - self.val_frac - self.test_frac) # Ensure non-zero
-            estimated_train_n = int(self.n_samples * estimated_train_frac)
-            val_test_size_per_step = max(min_samples_per_split_step, int(estimated_train_n * (self.val_frac + self.test_frac)))
+            logger.info("Estimating fixed train size based on n_splits, val_frac, test_frac.")
+            # Estimate based on the total space needed for all splits:
+            # n_samples >= fixed_train_n + val_size + test_size + (n_splits - 1) * step_size
+            # n_samples >= fixed_train_n + int(fixed_train_n*val_frac) + n_splits * int(fixed_train_n*test_frac)
+            # n_samples >= fixed_train_n * (1 + val_frac + n_splits * test_frac)
+            # fixed_train_n <= n_samples / (1 + val_frac + n_splits * test_frac)

-            # Tentative initial train size is total minus one val/test block
-            fixed_train_n_est = self.n_samples - val_test_size_per_step
+            denominator = 1.0 + self.val_frac + self.n_splits * self.test_frac
+            if denominator <= 1.0: # Avoid division by zero or non-positive, and ensure train frac < 1
+                raise ValueError(f"Cannot estimate initial_train_size. Combination of val_frac ({self.val_frac}), "
+                                 f"test_frac ({self.test_frac}), and n_splits ({self.n_splits}) is invalid (denominator {denominator:.2f} <= 1.0).")

-            # Basic sanity checks
-            if fixed_train_n_est <= 0:
-                raise ValueError("Could not estimate a valid initial_train_size (<= 0). Please specify it or check CV fractions.")
-            # Need at least 1 sample for train, val, test each theoretically
-            est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
-            est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
-            if fixed_train_n_est + est_val_size + est_test_size > self.n_samples:
-                # If the simple estimate is too large, reduce it more drastically
-                # Try setting train size = 50% and see if val/test fit?
-                fixed_train_n_est = int(self.n_samples * 0.5)
-                est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
-                est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
-                if fixed_train_n_est <=0 or (fixed_train_n_est + est_val_size + est_test_size > self.n_samples):
-                    raise ValueError("Could not estimate a valid initial_train_size. Data too small relative to val/test fractions? Please specify initial_train_size.")
+            estimated_size = int(self.n_samples / denominator)

-            logger.warning(f"initial_train_size not set, estimated fixed train size for rolling window: {fixed_train_n_est}. "
-                           "This is a heuristic; viability depends on n_splits and step size. Validation happens in split().")
-            return fixed_train_n_est
+            # Add a sanity check: ensure estimated size is reasonably large
+            min_required_for_features = 1 # Placeholder - ideally get from FeatureConfig if possible, but complex here
+            if estimated_size < min_required_for_features:
+                raise ValueError(f"Estimated fixed train size ({estimated_size}) is too small. "
+                                 f"Check CV config (n_splits={self.n_splits}, val_frac={self.val_frac}, test_frac={self.test_frac}) "
+                                 f"relative to total samples ({self.n_samples}). Consider specifying initial_train_size manually.")
+
+            logger.info(f"Estimated fixed training window size: {estimated_size}")
+            return estimated_size
         else:
-            raise ValueError(f"Invalid initial_train_size: {self.initial_train_size}")
+            raise ValueError(f"Invalid initial_train_size type or value: {self.initial_train_size}")


     def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
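Note: the derivation in the new comments is easy to sanity-check with concrete numbers. The values below are illustrative only, not taken from the config in this commit:

```python
n_samples, val_frac, test_frac, n_splits = 10_000, 0.1, 0.1, 5

denominator = 1.0 + val_frac + n_splits * test_frac  # 1.6
estimated_size = int(n_samples / denominator)        # 6250

# Feasibility check mirroring the constraint in the comments:
# a fixed train window plus one val block plus n_splits test blocks fit in the data.
total_needed = (estimated_size
                + int(estimated_size * val_frac)
                + n_splits * int(estimated_size * test_frac))
assert total_needed <= n_samples  # 6250 + 625 + 5 * 625 = 10000
print(estimated_size)  # 6250
```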
@@ -483,28 +487,31 @@ class TimeSeriesDataset(Dataset):
     """
     PyTorch Dataset for time series forecasting.

-    Takes a NumPy array (features + target), sequence length, and forecast horizon,
-    and returns (input_sequence, target_sequence) tuples. Compatible with PyTorch
-    DataLoaders used by PyTorch Lightning.
+    Takes a NumPy array (features + target), sequence length, and a list of
+    specific forecast horizons. Returns (input_sequence, target_vector) tuples,
+    where target_vector contains the target values at the specified future steps.
     """
-    def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: int, target_col_index: int = 0):
+    def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: List[int], target_col_index: int = 0):
         """
         Args:
             data_array: Numpy array of shape (n_samples, n_features).
                         Assumes the target variable is one of the columns.
             sequence_length: Length of the input sequence (lookback window).
-            forecast_horizon: Number of steps ahead to predict.
+            forecast_horizon: List of specific steps ahead to predict (e.g., [1, 6, 12]).
             target_col_index: Index of the target column in data_array. Defaults to 0.
         """
         if sequence_length <= 0:
             raise ValueError("sequence_length must be positive.")
-        if forecast_horizon <= 0:
-            raise ValueError("forecast_horizon must be positive.")
+        if not forecast_horizon or not isinstance(forecast_horizon, list) or any(h <= 0 for h in forecast_horizon):
+            raise ValueError("forecast_horizon must be a non-empty list of positive integers.")
         if data_array.ndim != 2:
             raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")
-        min_len_required = sequence_length + forecast_horizon
+        self.max_horizon = max(forecast_horizon) # Find the furthest point needed
+
+        min_len_required = sequence_length + self.max_horizon
         if min_len_required > data_array.shape[0]:
-            raise ValueError(f"sequence_length ({sequence_length}) + forecast_horizon ({forecast_horizon}) = {min_len_required} "
+            raise ValueError(f"sequence_length ({sequence_length}) + max_horizon ({self.max_horizon}) = {min_len_required} "
                              f"exceeds total samples provided ({data_array.shape[0]})")
         if not (0 <= target_col_index < data_array.shape[1]):
             raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")
@@ -512,32 +519,37 @@ class TimeSeriesDataset(Dataset):

         self.data = torch.tensor(data_array, dtype=torch.float32)
         self.sequence_length = sequence_length
-        self.forecast_horizon = forecast_horizon
+        self.forecast_horizon_list = sorted(forecast_horizon)
         self.target_col_index = target_col_index
         self.n_samples = data_array.shape[0]
         self.n_features = data_array.shape[1]

         logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
-                     f"seq_len={self.sequence_length}, forecast_horizon={self.forecast_horizon}, "
-                     f"target_idx={self.target_col_index}")
+                     f"seq_len={self.sequence_length}, forecast_horizons={self.forecast_horizon_list}, "
+                     f"max_horizon={self.max_horizon}, target_idx={self.target_col_index}")

     def __len__(self) -> int:
         """Returns the total number of sequences that can be generated."""
-        return self.n_samples - self.sequence_length - self.forecast_horizon + 1
+        return self.n_samples - self.sequence_length - self.max_horizon + 1

     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Returns a single (input_sequence, target_sequence) pair.
+        Returns a single (input_sequence, target_vector) pair.
+        Target vector contains values for the specified forecast horizons.
         """
         if not (0 <= idx < len(self)):
             raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")

         input_start = idx
         input_end = idx + self.sequence_length
-        input_sequence = self.data[input_start:input_end, :]
-        target_start = input_end
-        target_end = target_start + self.forecast_horizon
-        target_sequence = self.data[target_start:target_end, self.target_col_index]
-        return input_sequence, target_sequence
+        input_sequence = self.data[input_start:input_end, :] # Shape: (seq_len, n_features)
+
+        # Calculate indices for each horizon relative to the end of the input sequence
+        # Horizon h corresponds to index: input_end + h - 1
+        target_indices = [input_end + h - 1 for h in self.forecast_horizon_list]
+        target_vector = self.data[target_indices, self.target_col_index] # Shape: (len(forecast_horizon_list),)
+
+        return input_sequence, target_vector

 # --- Data Preparation ---
 def prepare_fold_data_and_loaders(
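Note: with list-valued horizons the target for each sample is gathered at `input_end + h - 1` for every `h`, and the dataset length shrinks by the furthest horizon only. A quick worked example using the config values in this commit (`sequence_length: 72`, `forecast_horizon: [1, 6, 12, 24]`); the sample count is illustrative:

```python
sequence_length = 72
forecast_horizons = [1, 6, 12, 24]

idx = 0
input_end = idx + sequence_length                 # 72 -> input rows 0..71
target_indices = [input_end + h - 1 for h in sorted(forecast_horizons)]
print(target_indices)                             # [72, 77, 83, 95]

# The last usable idx must leave room for the furthest horizon:
n_samples, max_horizon = 1_000, max(forecast_horizons)
dataset_len = n_samples - sequence_length - max_horizon + 1
print(dataset_len)                                # 905
```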
@@ -576,6 +588,7 @@ def prepare_fold_data_and_loaders(
         feature_config: Configuration for feature engineering.
         train_config: Configuration for training (used for batch size, device hints).
         eval_config: Configuration for evaluation (used for batch size).
+

     Returns:
         Tuple containing:
@@ -598,13 +611,25 @@ def prepare_fold_data_and_loaders(
     if feature_config.lags:
         max_lookback = max(max_lookback, max(feature_config.lags))
     if feature_config.rolling_window_sizes:
-        max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) -1 )
-    max_history_needed = max(max_lookback, feature_config.sequence_length)
+        max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) -1)
+    # Also need history for the input sequence length and max target horizon
+    max_horizon_needed = max(feature_config.forecast_horizon) if feature_config.forecast_horizon else 0
+    # Max history needed is max of lookback for features OR (sequence_length + max_horizon - 1) for targets/inputs
+    # Correct logic: Need `sequence_length` history for input, and `max_horizon` steps *after* the train data for targets/evaluation.
+    # The slicing needs to ensure enough data *before* train_idx[0] for feature lookback *and* sequence_length.
+    # Max history *before* the start of the training set
+    max_history_needed_before_train = max(max_lookback, feature_config.sequence_length)
+
+    slice_start_idx = max(0, train_idx[0] - max_history_needed_before_train)
+    # The end index needs to cover the test set PLUS the maximum horizon needed for the last test target
+    slice_end_idx = test_idx[-1] + max_horizon_needed # Go up to the last needed target
+
+    # Ensure end index is within bounds
+    slice_end_idx = min(slice_end_idx + 1, len(full_df)) # +1 because iloc is exclusive

-    slice_start_idx = max(0, train_idx[0] - max_history_needed)
-    slice_end_idx = test_idx[-1] + 1
     if slice_start_idx >= slice_end_idx:
-        raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices.")
+        raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices and horizon.")

     fold_data_slice = full_df.iloc[slice_start_idx:slice_end_idx]
     logger.debug(f"Required data slice for fold: indices {slice_start_idx} to {slice_end_idx-1} "
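Note: the new slice bounds can be checked with concrete numbers. All values below are illustrative, not taken from the project:

```python
train_start, test_end = 1_000, 2_000
max_lookback = 168            # largest lag; rolling windows contribute window - 1
sequence_length = 72
max_horizon_needed = 24       # max(forecast_horizon)
len_full_df = 5_000

max_history_needed_before_train = max(max_lookback, sequence_length)       # 168
slice_start_idx = max(0, train_start - max_history_needed_before_train)    # 832
slice_end_idx = test_end + max_horizon_needed                              # 2024 (last needed target row)
slice_end_idx = min(slice_end_idx + 1, len_full_df)                        # 2025, since iloc is exclusive
print(slice_start_idx, slice_end_idx)  # 832 2025
```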
@@ -709,22 +734,38 @@ def prepare_fold_data_and_loaders(

     input_size = train_data_scaled.shape[1]

+    # --- Ensure final data arrays are float32 for PyTorch ---
+    try:
+        # Explicitly convert to float32 AFTER scaling (or non-scaling)
+        train_data_final = train_data_scaled.astype(np.float32)
+        val_data_final = val_data_scaled.astype(np.float32)
+        test_data_final = test_data_scaled.astype(np.float32)
+        logger.debug("Ensured final data arrays are float32.")
+    except ValueError as e:
+        # This might happen if data cannot be safely cast (e.g., strings remain unexpectedly)
+        logger.error(f"Failed to convert data arrays to float32 before creating Tensors: {e}", exc_info=True)
+        # Consider adding more debug info here if it fails, e.g.:
+        # logger.debug(f"Data types in train_df before conversion: \n{train_df.dtypes}")
+        raise ValueError("Data could not be converted to numeric type (float32) for PyTorch.") from e
+
+
     # 6. Dataset Instantiation
     logger.debug("Creating TimeSeriesDataset instances for the fold.")
     try:
+        # Use the explicitly converted arrays
         train_dataset = TimeSeriesDataset(
-            train_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
+            train_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
         )
         val_dataset = TimeSeriesDataset(
-            val_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
+            val_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
         )
         test_dataset = TimeSeriesDataset(
-            test_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
+            test_data_final, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
         )
     except ValueError as e:
         logger.error(f"Error creating TimeSeriesDataset: {e}")
-        logger.error(f"Shapes fed to Dataset: Train={train_data_scaled.shape}, Val={val_data_scaled.shape}, Test={test_data_scaled.shape}")
-        logger.error(f"SeqLen={feature_config.sequence_length}, Horizon={feature_config.forecast_horizon}")
+        logger.error(f"Shapes fed to Dataset: Train={train_data_final.shape}, Val={val_data_final.shape}, Test={test_data_final.shape}")
+        logger.error(f"SeqLen={feature_config.sequence_length}, Horizons={feature_config.forecast_horizon}")
         raise


@@ -748,4 +789,69 @@ def prepare_fold_data_and_loaders(

     logger.info("Data loaders prepared successfully for the fold.")

     return train_loader, val_loader, test_loader, target_scaler, input_size
+
+
+# --- Classic Train/Val/Test Split ---
+def split_data_classic(
+    n_samples: int,
+    val_frac: float,
+    test_frac: float,
+    start_from_end: bool = True
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Splits data indices into one train, one validation, and one test set based on fractions.
+
+    Args:
+        n_samples: Total number of samples in the dataset.
+        val_frac: Fraction of the *total* data to use for validation.
+        test_frac: Fraction of the *total* data to use for testing.
+        start_from_end: If True (default), test and validation sets are taken from the end
+                        of the series. If False, they are taken after the initial training block.
+                        Default is True for typical time series evaluation.
+
+    Returns:
+        Tuple of (train_indices, val_indices, test_indices).
+
+    Raises:
+        ValueError: If fractions are invalid or sum to >= 1.
+    """
+    if not (0 < val_frac < 1):
+        raise ValueError(f"val_frac must be between 0 and 1, got {val_frac}")
+    if not (0 < test_frac < 1):
+        raise ValueError(f"test_frac must be between 0 and 1, got {test_frac}")
+    if val_frac + test_frac >= 1:
+        raise ValueError(f"Sum of val_frac ({val_frac}) and test_frac ({test_frac}) must be less than 1.")
+
+    test_size = math.ceil(n_samples * test_frac) # Use ceil to ensure at least one sample if frac is tiny
+    val_size = math.ceil(n_samples * val_frac)
+    train_size = n_samples - val_size - test_size
+
+    if train_size <= 0:
+        raise ValueError(f"Calculated train_size ({train_size}) is not positive. Adjust fractions or increase data.")
+    if val_size <= 0:
+        raise ValueError("Calculated val_size is not positive.")
+    if test_size <= 0:
+        raise ValueError("Calculated test_size is not positive.")
+
+    indices = np.arange(n_samples)
+
+    if start_from_end:
+        train_indices = indices[:train_size]
+        val_indices = indices[train_size:train_size + val_size]
+        test_indices = indices[train_size + val_size:]
+        # Adjust if ceil caused slight overallocation in test
+        test_indices = test_indices[:test_size]
+    else:
+        # Less common: place val/test directly after train
+        train_indices = indices[:train_size]
+        val_indices = indices[train_size:train_size + val_size]
+        test_indices = indices[train_size + val_size:train_size + val_size + test_size]
+        # Remaining data is unused in this scenario
+
+    logger.info(f"Classic split: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
+                f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
+                f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")
+
+    return train_indices, val_indices, test_indices
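Note: a quick standalone illustration of what the new `split_data_classic` computes, re-implementing its arithmetic inline on a toy series (the fractions are illustrative):

```python
import math

import numpy as np

n_samples, val_frac, test_frac = 100, 0.15, 0.15
test_size = math.ceil(n_samples * test_frac)           # 15
val_size = math.ceil(n_samples * val_frac)             # 15
train_size = n_samples - val_size - test_size          # 70

indices = np.arange(n_samples)
train_idx = indices[:train_size]                       # 0..69
val_idx = indices[train_size:train_size + val_size]    # 70..84
test_idx = indices[train_size + val_size:][:test_size]  # 85..99
print(len(train_idx), len(val_idx), len(test_idx))     # 70 15 15
```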
@@ -1,24 +1,22 @@
 import logging
-import os
 from pathlib import Path # Added
 import numpy as np
 import torch
 import torchmetrics
 from torch.utils.data import DataLoader
 from sklearn.preprocessing import StandardScaler, MinMaxScaler # For type hinting target_scaler
-from typing import Dict, Any, Optional, Union, List, Tuple
-# import matplotlib.pyplot as plt # No longer needed directly
-# import seaborn as sns # No longer needed directly
+from typing import Dict, Optional, Union, List
+import pandas as pd # For time index type hint

-# Assuming config_model and io.plotting are accessible
-from forecasting_model.utils.config_model import EvaluationConfig
-from forecasting_model.io.plotting import ( # Import the plotting utilities
+from forecasting_model.utils.forecast_config_model import EvaluationConfig
+from forecasting_model.train.model import LSTMForecastLightningModule
+from forecasting_model.io.plotting import (
     setup_plot_style,
     save_plot,
     create_time_series_plot,
     create_scatter_plot,
     create_residuals_plot,
-    create_residuals_distribution_plot
+    create_residuals_distribution_plot,
 )


@@ -82,90 +80,101 @@ def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     return float(rmse)


-# --- Plotting Functions (Utilities) ---
-# REMOVED - These are now imported from io.plotting
-
-
 # --- Fold Evaluation Function ---

 def evaluate_fold_predictions(
-    y_true_scaled: np.ndarray,
-    y_pred_scaled: np.ndarray,
+    y_true_scaled: np.ndarray, # Shape: (n_samples, len(horizons))
+    y_pred_scaled: np.ndarray, # Shape: (n_samples, len(horizons))
     target_scaler: Union[StandardScaler, MinMaxScaler, None],
     eval_config: EvaluationConfig,
-    fold_num: int,
-    output_dir: str, # Base output directory (e.g., output/cv_results)
-    time_index: Optional[np.ndarray] = None # Optional: Pass time index for x-axis
+    fold_num: int, # Zero-based fold index
+    output_dir: str, # Base output directory
+    plot_subdir: Optional[str] = "plots",
+    # time_index: Optional[Union[np.ndarray, pd.Index]] = None, # OLD: Index for samples
+    prediction_time_index: Optional[pd.Index] = None, # Index corresponding to the prediction times (n_samples,)
+    forecast_horizons: Optional[List[int]] = None, # The list of horizons predicted (e.g., [1, 6, 12])
+    plot_title_prefix: Optional[str] = None
 ) -> Dict[str, float]:
     """
-    Processes prediction results for a fold's test set using torchmetrics.
+    Processes prediction results (multiple horizons) for a fold or ensemble.

-    Takes scaled predictions and targets, inverse transforms them,
-    calculates final metrics (MAE, RMSE) using torchmetrics.functional,
-    and generates evaluation plots using utilities from io.plotting. Assumes
-    model inference is already done.
+    Takes scaled predictions and targets (shape: samples, num_horizons),
+    inverse transforms them, calculates overall metrics (MAE, RMSE) across all horizons,
+    and generates evaluation plots *for the first specified horizon only*.

     Args:
-        y_true_scaled: Numpy array of scaled ground truth targets (n_samples, horizon).
-        y_pred_scaled: Numpy array of scaled model predictions (n_samples, horizon).
-        target_scaler: The scaler fitted on the target variable during training. Needed
-                       for inverse transforming to original scale. Can be None.
-        eval_config: Configuration object for evaluation parameters (e.g., plotting).
-        fold_num: The current fold number (e.g., 0, 1, ...).
-        output_dir: The base directory to save fold-specific outputs (plots, metrics).
-        time_index: Optional array representing the time index for the test set,
-                    used for x-axis in time-based plots. If None, uses integer indices.
+        y_true_scaled: Numpy array of scaled ground truth targets (n_samples, len(horizons)).
+        y_pred_scaled: Numpy array of scaled model predictions (n_samples, len(horizons)).
+        target_scaler: The scaler fitted on the target variable.
+        eval_config: Configuration object for evaluation parameters.
+        fold_num: The current fold number (zero-based or -1 for classic).
+        output_dir: The base directory to save outputs.
+        plot_subdir: Specific subdirectory under output_dir for plots.
+        prediction_time_index: Pandas Index representing the time for each prediction point (n_samples,).
+                               Required for meaningful time plots.
+        forecast_horizons: List of horizons predicted (e.g., [1, 6, 12]). Required for plotting.
+        plot_title_prefix: Optional string to prepend to plot titles.

     Returns:
         Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
-        original scale. Metrics will be NaN if inverse transform or calculation fails.
-
-    Raises:
-        ValueError: If input shapes are inconsistent or required scaler is missing.
+        original scale, calculated *across all predicted horizons*.
     """
-    logger.info(f"Processing evaluation results for Fold {fold_num + 1}...")
-    fold_id = fold_num + 1 # Use 1-based indexing for reporting/filenames
+    fold_id_str = f"Fold {fold_num + 1}" if fold_num >= 0 else "Classic Run"
+    eval_context_str = f"{plot_title_prefix} {fold_id_str}" if plot_title_prefix else fold_id_str
+    logger.info(f"Processing evaluation results for: {eval_context_str}")

     if y_true_scaled.shape != y_pred_scaled.shape:
-        raise ValueError(f"Shape mismatch between targets and predictions: "
+        raise ValueError(f"Shape mismatch between targets and predictions for {eval_context_str}: "
                          f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
     if y_true_scaled.ndim != 2:
-        raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")
+        raise ValueError(f"Expected 2D arrays (samples, num_horizons) for {eval_context_str}, got {y_true_scaled.ndim}D")

-    n_samples, horizon = y_true_scaled.shape
-    logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")
+    n_samples, n_horizons = y_true_scaled.shape
+    logger.debug(f"Processing {n_samples} samples across {n_horizons} horizons for {eval_context_str}.")

     # --- Inverse Transform (Outputs NumPy) ---
-    y_true_flat_scaled = y_true_scaled.reshape(-1, 1)
-    y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)
+    # Flatten the multi-horizon arrays for the scaler (which expects (N, 1))
+    y_true_flat_scaled = y_true_scaled.reshape(-1, 1) # Shape: (n_samples * n_horizons, 1)
+    y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1) # Shape: (n_samples * n_horizons, 1)

     y_true_inv_np: np.ndarray
     y_pred_inv_np: np.ndarray

     if target_scaler is not None:
         try:
-            logger.debug("Inverse transforming predictions and targets.")
-            y_true_inv_np = target_scaler.inverse_transform(y_true_flat_scaled)
-            y_pred_inv_np = target_scaler.inverse_transform(y_pred_flat_scaled)
-            # Flatten NumPy arrays for metric calculation and plotting
-            y_true_np = y_true_inv_np.flatten()
-            y_pred_np = y_pred_inv_np.flatten()
+            logger.debug(f"Inverse transforming predictions and targets for {eval_context_str}.")
+            y_true_inv_flat = target_scaler.inverse_transform(y_true_flat_scaled)
+            y_pred_inv_flat = target_scaler.inverse_transform(y_pred_flat_scaled)
+            # Reshape back to (n_samples, n_horizons) for potential per-horizon analysis later
+            y_true_inv_np = y_true_inv_flat.reshape(n_samples, n_horizons)
+            y_pred_inv_np = y_pred_inv_flat.reshape(n_samples, n_horizons)
         except Exception as e:
-            logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
+            logger.error(f"Error during inverse scaling for {eval_context_str}: {e}", exc_info=True)
             logger.error("Metrics calculation will be skipped due to inverse transform failure.")
             return {'MAE': np.nan, 'RMSE': np.nan}
     else:
-        logger.info("No target scaler provided, assuming inputs are already on original scale.")
-        # Flatten NumPy arrays for metric calculation and plotting
-        y_true_np = y_true_flat_scaled.flatten()
-        y_pred_np = y_pred_flat_scaled.flatten()
-
-    # --- Calculate Metrics using torchmetrics.functional ---
-    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Initialize with NaN
+        logger.info(f"No target scaler provided for {eval_context_str}, assuming inputs are on original scale.")
+        y_true_inv_np = y_true_scaled # Keep original shape (n_samples, n_horizons)
+        y_pred_inv_np = y_pred_scaled # Keep original shape
+
+    # --- Calculate Metrics using torchmetrics.functional (Overall across all horizons) ---
+    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}
     try:
-        if len(y_true_np) > 0: # Check if data exists after potential failures
-            y_true_tensor = torch.from_numpy(y_true_np).float().cpu()
-            y_pred_tensor = torch.from_numpy(y_pred_np).float().cpu()
+        # Flatten arrays for overall metrics calculation
+        y_true_flat_for_metrics = y_true_inv_np.flatten()
+        y_pred_flat_for_metrics = y_pred_inv_np.flatten()
+
+        valid_mask = ~np.isnan(y_true_flat_for_metrics) & ~np.isnan(y_pred_flat_for_metrics)
+        if np.sum(valid_mask) < len(y_true_flat_for_metrics):
+            nan_count = len(y_true_flat_for_metrics) - np.sum(valid_mask)
+            logger.warning(f"{nan_count} NaN values found in predictions/targets (across all horizons) for {eval_context_str}. These will be excluded from metrics.")
+
+        if np.sum(valid_mask) > 0:
+            y_true_tensor = torch.from_numpy(y_true_flat_for_metrics[valid_mask]).float().cpu()
+            y_pred_tensor = torch.from_numpy(y_pred_flat_for_metrics[valid_mask]).float().cpu()

             mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
             mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
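Note: the inverse-transform block above relies on sklearn scalers expecting a 2-D `(N, 1)` input, so the `(n_samples, n_horizons)` arrays are flattened for `inverse_transform` and reshaped back afterwards. A minimal sketch of that shape round-trip (for simplicity the scaler here is fitted on the flattened toy array, whereas the pipeline fits it on the training target):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

n_samples, n_horizons = 4, 3
rng = np.random.default_rng(0)
y_orig = rng.normal(100.0, 10.0, size=(n_samples, n_horizons))

scaler = StandardScaler().fit(y_orig.reshape(-1, 1))
y_scaled = scaler.transform(y_orig.reshape(-1, 1)).reshape(n_samples, n_horizons)

# Inverse transform: flatten to (N, 1), invert, restore (n_samples, n_horizons).
y_back = scaler.inverse_transform(y_scaled.reshape(-1, 1)).reshape(n_samples, n_horizons)
assert np.allclose(y_back, y_orig)
```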
@@ -174,82 +183,95 @@ def evaluate_fold_predictions(
             metrics['MAE'] = mae_tensor.item()
             metrics['RMSE'] = rmse_tensor.item()

-            logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
+            logger.info(f"{eval_context_str} Test Set Overall Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f} (across all horizons)")
         else:
-            logger.warning(f"Skipping metric calculation for Fold {fold_id} due to empty data after inverse transform.")
+            logger.warning(f"Skipping metric calculation for {eval_context_str} due to no valid (non-NaN) data points.")

     except Exception as e:
-        logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
-        # metrics already initialized to NaN
+        logger.error(f"Failed to calculate overall metrics using torchmetrics for {eval_context_str}: {e}", exc_info=True)


-    # --- Generate Plots (Optional - uses plotting utilities) ---
-    if eval_config.save_plots and len(y_true_np) > 0:
-        logger.info(f"Generating evaluation plots for Fold {fold_id}...")
-        # Define plot directory and setup style
-        fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
-        setup_plot_style() # Apply consistent styling
-
-        title_suffix = f"Fold {fold_id} Test Set"
-        residuals_np = y_true_np - y_pred_np
-
-        # Determine x-axis: use provided time_index if available, else integer indices
-        # Note: Flattened y_true/y_pred have length n_samples * horizon
-        # Need an appropriate index for this flattened view if time_index is provided.
-        # Simple approach: use integer indices for flattened data.
-        plot_indices = np.arange(len(y_true_np))
-        xlabel = "Time Index (Flattened Horizon x Samples)"
-        # If time_index corresponding to the start of each forecast is passed,
-        # more sophisticated x-axis handling could be done, but integer indices are simpler.
-
-        try:
-            # Create and save each plot using utility functions
-            fig_ts = create_time_series_plot(
-                plot_indices, y_true_np, y_pred_np,
-                f"Predictions vs Actual - {title_suffix}",
-                xlabel=xlabel,
-                ylabel="Value (Original Scale)",
-                max_points=eval_config.plot_sample_size
-            )
-            save_plot(fig_ts, fold_plot_dir / "predictions_vs_actual.png")
-
-            fig_scatter = create_scatter_plot(
-                y_true_np, y_pred_np,
-                f"Scatter Plot - {title_suffix}",
-                xlabel="Actual Values (Original Scale)",
-                ylabel="Predicted Values (Original Scale)"
-            )
-            save_plot(fig_scatter, fold_plot_dir / "scatter_predictions.png")
-
-            fig_res_time = create_residuals_plot(
-                plot_indices, residuals_np,
-                f"Residuals Over Time - {title_suffix}",
-                xlabel=xlabel,
-                ylabel="Residual (Original Scale)",
-                max_points=eval_config.plot_sample_size
-            )
-            save_plot(fig_res_time, fold_plot_dir / "residuals_time.png")
-
-            fig_res_dist = create_residuals_distribution_plot(
-                residuals_np,
-                f"Residuals Distribution - {title_suffix}",
-                xlabel="Residual Value (Original Scale)",
-                ylabel="Density"
-            )
-            save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution.png")
-
-            logger.info(f"Evaluation plots saved to: {fold_plot_dir}")
-
-        except Exception as e:
-            logger.error(f"Failed to generate or save one or more plots for Fold {fold_id}: {e}", exc_info=True)
-            # Continue without plots, metrics are already calculated.
-
-    elif eval_config.save_plots and len(y_true_np) == 0:
-        logger.warning(f"Skipping plot generation for Fold {fold_id} due to empty data.")
-
-    logger.info(f"Evaluation processing finished for Fold {fold_id}.")
+    # --- Generate Plots (Optional - Focus on FIRST horizon) ---
+    if eval_config.save_plots and np.sum(valid_mask) > 0:
+        if forecast_horizons is None or not forecast_horizons:
+            logger.warning(f"Skipping plot generation for {eval_context_str}: `forecast_horizons` list not provided.")
+        elif prediction_time_index is None or len(prediction_time_index) != n_samples:
+            logger.warning(f"Skipping plot generation for {eval_context_str}: `prediction_time_index` is missing or has incorrect length ({len(prediction_time_index) if prediction_time_index is not None else 'None'} != {n_samples}).")
+        else:
+            logger.info(f"Generating evaluation plots for {eval_context_str} (using first horizon H+{forecast_horizons[0]} only)...")
+            base_plot_dir = Path(output_dir)
+            fold_plot_dir = base_plot_dir / plot_subdir if plot_subdir else base_plot_dir
+            setup_plot_style()
+
+            # --- Plotting for the FIRST horizon ---
+            first_horizon = forecast_horizons[0]
+            y_true_h1 = y_true_inv_np[:, 0] # Data for the first horizon
+            y_pred_h1 = y_pred_inv_np[:, 0] # Data for the first horizon
+            residuals_h1 = y_true_h1 - y_pred_h1
+
+            # Calculate the actual time index for the first horizon's targets
+            # Requires the original dataset's frequency if available, otherwise assumes simple offset
+            target_time_index_h1 = prediction_time_index
+            try:
+                # Assuming prediction_time_index corresponds to the *time* of prediction
+                # The target for H+h occurs `h` steps later.
+                # This requires a DatetimeIndex with a frequency.
+                if isinstance(prediction_time_index, pd.DatetimeIndex) and prediction_time_index.freq:
+                    time_offset = pd.Timedelta(first_horizon, unit=prediction_time_index.freq.name)
+                    target_time_index_h1 = prediction_time_index + time_offset
+                    xlabel_h1 = f"Time (Target H+{first_horizon})"
+                else:
+                    logger.warning(f"Prediction time index lacks frequency info. Using original prediction time for H+{first_horizon} plot x-axis.")
+                    xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"
+            except Exception as time_err:
+                logger.warning(f"Could not calculate target time index for H+{first_horizon}: {time_err}. Using prediction time index for x-axis.")
+                xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"
+
+            title_suffix = f"- {eval_context_str} (H+{first_horizon})"
+
+            try:
+                fig_ts = create_time_series_plot(
+                    target_time_index_h1, y_true_h1, y_pred_h1, # Use H1 data and time
+                    f"Predictions vs Actual {title_suffix}",
+                    xlabel=xlabel_h1, ylabel="Value (Original Scale)",
+                    max_points=eval_config.plot_sample_size
+                )
+                save_plot(fig_ts, fold_plot_dir / f"predictions_vs_actual_h{first_horizon}.png")
+
+                fig_scatter = create_scatter_plot(
+                    y_true_h1, y_pred_h1, # Use H1 data
+                    f"Scatter Plot {title_suffix}",
+                    xlabel="Actual Values (Original Scale)", ylabel="Predicted Values (Original Scale)"
+                )
+                save_plot(fig_scatter, fold_plot_dir / f"scatter_predictions_h{first_horizon}.png")
+
+                fig_res_time = create_residuals_plot(
+                    target_time_index_h1, residuals_h1, # Use H1 residuals and time
+                    f"Residuals Over Time {title_suffix}",
+                    xlabel=xlabel_h1, ylabel="Residual (Original Scale)",
+                    max_points=eval_config.plot_sample_size
+                )
+                save_plot(fig_res_time, fold_plot_dir / f"residuals_time_h{first_horizon}.png")
+
+                # Residual distribution can use residuals from ALL horizons
+                residuals_all = y_true_inv_np.flatten() - y_pred_inv_np.flatten()
+                fig_res_dist = create_residuals_distribution_plot(
+                    residuals_all, # Use all residuals
+                    f"Residuals Distribution {eval_context_str} (All Horizons)", # Adjusted title
+                    xlabel="Residual Value (Original Scale)", ylabel="Density"
+                )
+                save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution_all_horizons.png")
+
+                logger.info(f"Evaluation plots saved to: {fold_plot_dir}")
+
+            except Exception as e:
+                logger.error(f"Failed to generate or save one or more plots for {eval_context_str}: {e}", exc_info=True)
+
+    elif eval_config.save_plots and np.sum(valid_mask) == 0:
+        logger.warning(f"Skipping plot generation for {eval_context_str} due to no valid data points.")
+
+    logger.info(f"Evaluation processing finished for {eval_context_str}.")
     return metrics


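Note: the `pd.Timedelta(first_horizon, unit=prediction_time_index.freq.name)` construction above only works when `freq.name` is a unit that `Timedelta` understands (e.g. an hourly index). An equivalent and slightly more general alternative, shown here only as a sketch and not as the code in this diff, is to add a multiple of the index's own offset:

```python
import pandas as pd

prediction_time_index = pd.date_range("2024-01-01", periods=4, freq="h")
first_horizon = 24

# Shift each prediction time forward by `first_horizon` steps of the index frequency.
target_time_index_h1 = prediction_time_index + first_horizon * prediction_time_index.freq
print(target_time_index_h1[0])  # 2024-01-02 00:00:00
```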
@ -257,63 +279,90 @@ def evaluate_fold_predictions(

# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
    model: LSTMForecastLightningModule,  # Use the specific type
    test_loader: DataLoader,
    device: torch.device,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str,
    # time_index: Optional[Union[np.ndarray, pd.Index]] = None,  # OLD
    prediction_time_index: Optional[pd.Index] = None,  # Pass prediction time index
    forecast_horizons: Optional[List[int]] = None  # Pass horizons
) -> Dict[str, float]:
    """
    [Optional Function] Evaluates a given model on a fold's test set.
    Handles multiple forecast horizons.

    Runs the inference loop, collects scaled results, then processes them using
    `evaluate_fold_predictions` (which now uses the plotting utilities).
    Useful for standalone testing or if not using pl.Trainer.test().
    """
    logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
    model.eval()
    model.to(device)

    all_preds_scaled_list: List[torch.Tensor] = []
    all_targets_scaled_list: List[torch.Tensor] = []

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            try:
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    X_batch, y_batch = batch  # y_batch shape: (batch, len(horizons))
                    targets_present = True
                else:
                    X_batch = batch
                    y_batch = None
                    targets_present = False

                X_batch = X_batch.to(device)
                outputs = model(X_batch)  # Scaled outputs: (batch, len(horizons))

                all_preds_scaled_list.append(outputs.cpu())
                if targets_present and y_batch is not None:
                    if outputs.shape != y_batch.shape:
                        raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
                    all_targets_scaled_list.append(y_batch.cpu())
                # ... error/warning if targets expected but not found ...

            except Exception as e:
                logger.error(f"Error during inference batch {i} for Fold {fold_num+1}: {e}", exc_info=True)
                raise ValueError(f"Inference failed on batch {i} for Fold {fold_num+1}")

    # --- Concatenate results ---
    try:
        if not all_preds_scaled_list:
            # ... handle no predictions ...
            return {'MAE': np.nan, 'RMSE': np.nan}

        # Resulting shapes: (n_samples, len(horizons))
        y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()

        y_true_scaled = None
        if all_targets_scaled_list:
            y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
        elif targets_present:
            # ... handle missing targets ...
            return {'MAE': np.nan, 'RMSE': np.nan}
        else:
            # ... handle no targets available ...
            return {'MAE': np.nan, 'RMSE': np.nan}
    except Exception as e:
        # ... error handling ...
        raise ValueError("Failed to combine batch results during evaluation inference.")

    if y_true_scaled is None:
        # ... handle missing targets ...
        return {'MAE': np.nan, 'RMSE': np.nan}

    # Ensure forecast_horizons are passed if available from the model;
    # retrieve them from the model's hparams if not passed explicitly.
    if forecast_horizons is None:
        try:
            # Assuming forecast_horizon list is stored in model_config hparam
            forecast_horizons = model.hparams.model_config.forecast_horizon
        except AttributeError:
            logger.warning("Could not retrieve forecast_horizons from model hparams for evaluation.")

    # Process the collected predictions
    return evaluate_fold_predictions(
        y_true_scaled=y_true_scaled,
        y_pred_scaled=y_pred_scaled,
@ -321,5 +370,8 @@ def evaluate_model_on_fold_test_set(
        eval_config=eval_config,
        fold_num=fold_num,
        output_dir=output_dir,
        # time_index=time_index  # OLD
        prediction_time_index=prediction_time_index,  # Pass through
        forecast_horizons=forecast_horizons,  # Pass through
        plot_title_prefix=f"Test Fold {fold_num + 1}"  # Example prefix
    )
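
# Usage sketch for the standalone evaluation path above. The names `trained_model`,
# `fold_test_loader`, `scaler`, and `cfg` are placeholders for objects produced earlier
# in the pipeline and are not defined in this commit.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   metrics = evaluate_model_on_fold_test_set(
#       model=trained_model,
#       test_loader=fold_test_loader,
#       device=device,
#       target_scaler=scaler,
#       eval_config=cfg.evaluation,
#       fold_num=0,
#       output_dir="output/fold_01",
#       forecast_horizons=cfg.features.forecast_horizon,
#   )
#   # metrics is a dict such as {'MAE': ..., 'RMSE': ...}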
@ -1,11 +1,15 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union, List
import logging
import pandas as pd

from pathlib import Path

# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler

logger = logging.getLogger(__name__)


def setup_plot_style(use_seaborn: bool = True) -> None:
@ -17,14 +21,16 @@ def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    if use_seaborn:
        try:
            # Use a style that works better when plotting multiple lines
            sns.set_theme(style="whitegrid", palette="viridis")  # Changed palette
            plt.rcParams['figure.figsize'] = (15, 7)  # Slightly larger default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        # Optional: Define a default matplotlib style if seaborn is not used
        plt.style.use('default')
        plt.rcParams['figure.figsize'] = (15, 7)
        logger.debug("Using default matplotlib plot style.")


def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
@ -49,16 +55,21 @@ def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
        logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
        # Not re-raised: directory-creation failures are only logged, so a plotting problem
        # cannot abort the surrounding run; the figure is still closed in the finally block.
    except Exception as e:
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
        # Not re-raised for the same reason: plot-saving failures are logged and swallowed.
    finally:
        # Close the figure to free up memory, regardless of saving success or failure
        try:
            plt.close(fig)
            logger.debug(f"Closed figure for plot {filepath}.")
        except Exception as e:
            logger.warning(f"Failed to close figure for plot {filepath}: {e}")
def create_time_series_plot(
    x: Union[np.ndarray, pd.Index],  # Allow pd.Index for time axis
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
@ -68,9 +79,9 @@ def create_time_series_plot(
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.
    NOTE: When using multi-horizon forecasts, this typically plots only ONE selected horizon.

    Args:
        x: The array or index for the x-axis (e.g., time steps, datetime index). Should align with y_true/y_pred.
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
@ -84,8 +95,9 @@ def create_time_series_plot(
    Raises:
        ValueError: If input array shapes are incompatible.
    """
    # Check shapes and types, accepting a pd.Index for x
    if not isinstance(x, (np.ndarray, pd.Index)) or x.shape[0] != y_true.shape[0] or x.shape[0] != y_pred.shape[0] or y_true.ndim != 1 or y_pred.ndim != 1:
        raise ValueError(f"Input shapes mismatch or invalid types: x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x.")
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure or raise error? Let's return empty.
@ -304,4 +316,243 @@ def create_residuals_distribution_plot(
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig

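# Usage sketch for the plotting helpers above, with synthetic data. The import path follows
# the forecasting_model.io.plotting module referenced elsewhere in this commit; the sine-wave
# series and output filename are made up for illustration, and the remaining plot arguments
# are assumed to keep their defaults.
#
#   import numpy as np
#   import pandas as pd
#   from forecasting_model.io.plotting import setup_plot_style, create_time_series_plot, save_plot
#
#   setup_plot_style()
#   x = pd.date_range("2024-01-01", periods=200, freq="h")        # hourly time axis
#   y_true = np.sin(np.arange(200) / 10.0)                        # synthetic actuals
#   y_pred = y_true + np.random.normal(scale=0.1, size=200)       # noisy "forecast"
#   fig = create_time_series_plot(x, y_true, y_pred, title="Synthetic example")
#   save_plot(fig, "output/plots/synthetic_example.png")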

def create_multi_horizon_time_series_plot(
    y_true_scaled_all_horizons: np.ndarray,  # (N, H)
    y_pred_scaled_all_horizons: np.ndarray,  # (N, H)
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
    prediction_time_index_h1: pd.DatetimeIndex,  # Time index for the first horizon predictions
    forecast_horizons: List[int],
    title: str,
    xlabel: str = "Time",
    ylabel: str = "Value (Original Scale)",
    max_points: Optional[int] = 1000  # Limit points for clarity
) -> plt.Figure:
    """
    Create a time series plot comparing actual values to predictions for multiple horizons.
    Predictions for each horizon are plotted on their corresponding target time step.

    Args:
        y_true_scaled_all_horizons: Ground truth values (N, H array) on the scaled scale.
        y_pred_scaled_all_horizons: Predicted values (N, H array) on the scaled scale.
        target_scaler: The scaler used for the target variable, needed for inverse transform.
        prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
            Length should be N.
        forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input shapes are incompatible or horizons list is invalid.
    """
    if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
        raise ValueError(f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match.")
    if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
        raise ValueError(f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}.")
    if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
        raise ValueError(f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match the number of predictions ({y_true_scaled_all_horizons.shape[0]}).")
    if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
        logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
    if not forecast_horizons:
        raise ValueError("forecast_horizons list cannot be empty.")

    logger.debug(f"Creating multi-horizon time series plot: {title}")
    setup_plot_style()  # Apply standard style

    fig, ax = plt.subplots(figsize=(18, 8))  # Larger figure for multi-horizon

    n_points = y_true_scaled_all_horizons.shape[0]
    plot_indices = np.arange(n_points)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = plot_indices[::step]
        # Subsample the data and index
        y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
        y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
        time_index_h1_plot = prediction_time_index_h1[plot_indices]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        y_true_scaled_plot = y_true_scaled_all_horizons
        y_pred_scaled_plot = y_pred_scaled_all_horizons
        time_index_h1_plot = prediction_time_index_h1
        effective_title = title

    # Inverse transform the (possibly subsampled) data
    y_true_inv_plot = None
    y_pred_inv_plot = None
    if target_scaler is not None:
        try:
            # Scaler expects (N * H, 1), reshape (N, H) to (N*H, 1)
            y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
            y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
            # Reshape back to (N, H)
            y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
            y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
            logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
        except Exception as e:
            logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
            # Fallback to plotting scaled data if inverse transform fails
            y_true_inv_plot = y_true_scaled_plot
            y_pred_inv_plot = y_pred_scaled_plot
            ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
    else:
        # No scaler provided: plot the scaled values directly instead of returning an empty figure
        y_true_inv_plot = y_true_scaled_plot
        y_pred_inv_plot = y_pred_scaled_plot
        ylabel = f"{ylabel} (Scaled Data)"

    if y_true_inv_plot is None or y_pred_inv_plot is None:
        # This should not happen with the fallbacks above, but as a safeguard
        logger.error("Inverse transformed data is None, cannot plot.")
        return fig  # Return empty figure

    # Plot Actuals (using h1's time index, as it's the reference point)
    ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-', markersize=4, linewidth=1.5, color='black')  # Actuals for H1

    # Plot predictions for each horizon
    colors = sns.color_palette("viridis", len(forecast_horizons))  # Use palette for distinct colors
    linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1)  # Cycle through linestyles

    for i, horizon in enumerate(forecast_horizons):
        preds_h = y_pred_inv_plot[:, i]
        # Calculate the time index for this specific horizon by shifting the h1 index.
        # Assumes the time index frequency is appropriate for the horizon steps.
        try:
            time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h')  # Assuming 'h' for hours
            ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x', linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
        except Exception as e:
            logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. Skipping plot for this horizon.", exc_info=True)

    # Configure plot appearance
    ax.set_title(effective_title, fontsize=16)  # Slightly larger title
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(fontsize=10)  # Smaller legend font
    ax.grid(True, linestyle='--', alpha=0.6)

    # Improve x-axis readability for datetimes
    fig.autofmt_xdate()  # Auto-rotate date labels

    fig.tight_layout()

    return fig


def plot_loss_curve_from_csv(
    metrics_csv_path: Union[str, Path],
    output_path: Union[str, Path],
    title: str = "Training Loss Curve",
    train_loss_col: str = "train_loss",  # Changed to match logging in model.py
    val_loss_col: str = "val_loss",  # Common validation loss metric logged by PL
    epoch_col: str = "epoch"
) -> None:
    """
    Reads training metrics from a PyTorch Lightning CSVLogger file and plots
    training and validation loss curves over epochs.

    Args:
        metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
        output_path: Path where the plot image will be saved.
        title: Title for the plot.
        train_loss_col: Name of the column containing epoch-level training loss.
        val_loss_col: Name of the column containing epoch-level validation loss.
        epoch_col: Name of the column containing the epoch number.

    Raises:
        FileNotFoundError: If the metrics_csv_path does not exist.
        KeyError: If required columns are not found in the CSV.
        Exception: For other plotting or file reading errors.
    """
    logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
    metrics_path = Path(metrics_csv_path)
    if not metrics_path.is_file():
        raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")

    try:
        metrics_df = pd.read_csv(metrics_path)

        # Check if required columns exist
        required_cols = [epoch_col, train_loss_col]
        # The validation column might be the scaled loss or the original-scale MAE
        possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale']  # Include potential names

        found_val_col = None
        for col in possible_val_cols:
            if col in metrics_df.columns:
                found_val_col = col
                break

        if not found_val_col:
            missing_cols = [col for col in required_cols if col not in metrics_df.columns]
            raise KeyError(f"Missing required columns in {metrics_path}: {missing_cols} or a suitable validation loss/metric column from {possible_val_cols}.")

        # --- Plotting ---
        setup_plot_style()  # Apply standard style
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color1 = 'tab:red'
        ax1.set_xlabel(epoch_col.capitalize())
        # Adjust ylabel based on the actual column name used for the train loss
        ax1.set_ylabel(train_loss_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color1)
        # Drop NaNs specific to this column for plotting integrity
        train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])
        # Filter for epoch column only if needed (usually not for loss plots)
        # train_plot_data = train_plot_data[train_plot_data[epoch_col].notna()]

        # Ensure epoch starts from 0 or 1 consistently
        if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
            ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
            logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
        else:
            ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')

        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')

        # Validation loss/metric plotting on twin axis
        ax2 = ax1.twinx()
        color2 = 'tab:blue'
        # Adjust ylabel based on the actual column name used for the validation metric
        ax2.set_ylabel(found_val_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color2)
        # Drop NaNs specific to the found validation column
        val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])
        # val_plot_data = val_plot_data[val_plot_data[epoch_col].notna()]  # Ensure epoch is not NaN

        # Ensure epoch starts from 0 or 1 consistently
        if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
            ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
            logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
        else:
            ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')

        ax2.tick_params(axis='y', labelcolor=color2)

        # Add legend manually, combining lines from both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='upper right')

        plt.title(title, fontsize=14)
        fig.tight_layout()  # Otherwise the right y-label is slightly clipped

        # Save the plot
        save_plot(fig, output_path)

    except pd.errors.EmptyDataError:
        logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
    except KeyError as e:
        logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
        raise  # Re-raise specific error after logging
    except Exception as e:
        logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
        raise  # Re-raise general errors
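
# Usage sketch for plot_loss_curve_from_csv with a synthetic metrics.csv. The column names
# mirror the defaults above ('epoch', 'train_loss', 'val_loss'); the file locations are
# placeholders chosen for this example only.
#
#   import pandas as pd
#   from forecasting_model.io.plotting import plot_loss_curve_from_csv
#
#   demo = pd.DataFrame({
#       "epoch": [0, 1, 2, 3],
#       "train_loss": [1.00, 0.62, 0.45, 0.38],
#       "val_loss": [0.95, 0.70, 0.55, 0.50],
#   })
#   demo.to_csv("demo_metrics.csv", index=False)
#   plot_loss_curve_from_csv("demo_metrics.csv", "demo_loss_curve.png", title="Demo Loss Curve")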
20  forecasting_model/train/__init__.py  Normal file
@ -0,0 +1,20 @@
"""
TODO
"""

__version__ = "0.1.0"

# Expose core components for easier import
from .ensemble_evaluation import (
    run_ensemble_evaluation
)

# Expose main configuration class from utils
from ..utils import MainConfig

# Define __all__ for explicit public API (optional but good practice)
__all__ = [
    "run_ensemble_evaluation",
    "MainConfig",
]
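
# With the package exports above, downstream code can import the public API directly
# (a minimal sketch; the call site and arguments depend on the rest of the pipeline):
#
#   from forecasting_model.train import run_ensemble_evaluation, MainConfig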
276  forecasting_model/train/classic.py  Normal file
@ -0,0 +1,276 @@
"""
Classic training routine: Train on initial data segment, validate and test on final segments.
"""

import logging
import time
from pathlib import Path
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from typing import Dict, Optional

from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions

from forecasting_model.utils.helper import save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv

logger = logging.getLogger(__name__)


def run_classic_training(
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path
) -> Optional[Dict[str, float]]:
    """
    Runs a single training pipeline using a classic train/val/test split.

    Args:
        config: The main configuration object.
        full_df: The complete raw DataFrame.
        output_base_dir: The base directory where general outputs are saved.
            Classic results will be saved in a subdirectory.

    Returns:
        A dictionary containing test metrics (e.g., {'MAE': ..., 'RMSE': ...})
        for the classic run, or None if it fails.
    """
    run_start_time = time.perf_counter()
    logger.info("--- Starting Classic Training Run ---")

    # Define a specific output directory for this run
    classic_output_dir = output_base_dir / "classic_run"
    classic_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Classic run outputs will be saved to: {classic_output_dir}")

    test_metrics: Optional[Dict[str, float]] = None
    best_val_score: Optional[float] = None
    best_model_path: Optional[str] = None

    try:
        # --- Data Splitting ---
        logger.info("Splitting data into classic train/val/test sets...")
        n_samples = len(full_df)
        val_frac = config.cross_validation.val_size_fraction
        test_frac = config.cross_validation.test_size_fraction
        train_idx, val_idx, test_idx = split_data_classic(n_samples, val_frac, test_frac)

        # Store test datetime index for evaluation plotting
        test_datetime_index = full_df.iloc[test_idx].index

        # --- Data Preparation ---
        logger.info("Preparing data loaders for the classic split...")
        train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
            full_df=full_df,
            train_idx=train_idx,
            val_idx=val_idx,
            test_idx=test_idx,
            target_col=config.data.target_col,
            feature_config=config.features,
            train_config=config.training,
            eval_config=config.evaluation
        )
        logger.info(f"Data loaders prepared. Input size determined: {input_size}")

        # Save artifacts specific to this run if needed (e.g., for later inference)
        torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
        torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
        torch.save(input_size, classic_output_dir / "classic_input_size.pt")
        # Save config for this run
        try:
            config_dump = config.model_dump()  # Pydantic v2
        except AttributeError:
            config_dump = config.dict()  # Fallback for Pydantic v1
        with open(classic_output_dir / "config.yaml", 'w') as f:
            yaml.dump(config_dump, f, default_flow_style=False)

        # --- Model Initialization ---
        model = LSTMForecastLightningModule(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler
        )
        logger.info("Classic LSTMForecastLightningModule initialized.")

        # --- PyTorch Lightning Callbacks ---
        monitor_metric = "val_MeanAbsoluteError"  # Monitor same metric as CV folds
        monitor_mode = "min"

        early_stop_callback = None
        if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
            early_stop_callback = EarlyStopping(
                monitor=monitor_metric, min_delta=0.0001,
                patience=config.training.early_stopping_patience, verbose=True, mode=monitor_mode
            )
            logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

        checkpoint_callback = ModelCheckpoint(
            dirpath=classic_output_dir / "checkpoints",
            filename="best_classic_model",  # Simple filename
            save_top_k=1, monitor=monitor_metric, mode=monitor_mode, verbose=True
        )
        logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        callbacks = [checkpoint_callback, lr_monitor]
        if early_stop_callback:
            callbacks.append(early_stop_callback)

        # --- PyTorch Lightning Logger ---
        pl_logger = CSVLogger(save_dir=str(classic_output_dir), name="training_logs")
        logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")

        # --- PyTorch Lightning Trainer ---
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        devices = 1 if accelerator == 'gpu' else None
        precision = getattr(config.training, 'precision', 32)

        trainer = pl.Trainer(
            accelerator=accelerator, devices=devices,
            max_epochs=config.training.epochs,
            callbacks=callbacks, logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader) // 10),
            enable_progress_bar=True,
            gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
            precision=precision,
        )
        logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

        # --- Training ---
        logger.info("Starting classic model training...")
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        logger.info("Classic model training finished.")

        # Store best validation score and path
        best_val_score_tensor = trainer.checkpoint_callback.best_model_score
        best_model_path = trainer.checkpoint_callback.best_model_path
        best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None

        if best_val_score is not None:
            logger.info(f"Best validation score ({monitor_metric}): {best_val_score:.4f}")
            logger.info(f"Best model checkpoint path: {best_model_path}")
        else:
            logger.warning(f"Could not retrieve best validation score/path (metric: {monitor_metric}). Evaluation might use last model.")
            best_model_path = None

        # --- Prediction on Test Set ---
        logger.info("Starting prediction on classic test set using best checkpoint...")
        prediction_results_list = trainer.predict(
            ckpt_path=best_model_path if best_model_path else 'last',
            dataloaders=test_loader
        )

        # --- Evaluation ---
        if not prediction_results_list:
            logger.error("Predict phase did not return any results for classic run.")
            test_metrics = None
        else:
            try:
                # Shapes: (n_samples, len(horizons))
                all_preds_scaled = torch.cat([b['preds_scaled'] for b in prediction_results_list], dim=0).numpy()
                n_predictions = len(all_preds_scaled)  # Number of samples actually predicted

                if 'targets_scaled' in prediction_results_list[0]:
                    all_targets_scaled = torch.cat([b['targets_scaled'] for b in prediction_results_list], dim=0).numpy()
                    if len(all_targets_scaled) != n_predictions:
                        logger.error(f"Classic Run: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
                        raise ValueError("Prediction and target count mismatch during classic evaluation.")
                else:
                    raise ValueError("Targets missing from prediction results.")

                logger.info(f"Processing {n_predictions} prediction results for classic test set...")

                # --- Calculate Correct Time Index for Plotting (First Horizon) ---
                target_time_index_for_plotting = None
                if test_idx is not None and config.features.forecast_horizon:
                    try:
                        test_block_index = full_df.index[test_idx]  # Use the test_idx from classic split
                        seq_len = config.features.sequence_length
                        first_horizon = config.features.forecast_horizon[0]
                        start_offset = seq_len + first_horizon - 1
                        if start_offset < len(test_block_index):
                            end_index = min(start_offset + n_predictions, len(test_block_index))
                            target_time_index_for_plotting = test_block_index[start_offset:end_index]
                            if len(target_time_index_for_plotting) != n_predictions:
                                logger.warning(f"Classic Run: Calculated target time index length ({len(target_time_index_for_plotting)}) "
                                               f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
                                target_time_index_for_plotting = None
                        else:
                            logger.warning(f"Classic Run: Cannot calculate target time index, start offset ({start_offset}) "
                                           f"exceeds test block length ({len(test_block_index)}).")
                    except Exception as e:
                        logger.error(f"Classic Run: Error calculating target time index for plotting: {e}", exc_info=True)
                        target_time_index_for_plotting = None  # Ensure it's None if an error occurs
                else:
                    logger.warning("Classic Run: Skipping target time index calculation (missing test_idx or forecast_horizon).")
                # --- End Index Calculation ---

                # Use the classic-run-specific objects and config
                test_metrics = evaluate_fold_predictions(
                    y_true_scaled=all_targets_scaled,
                    y_pred_scaled=all_preds_scaled,
                    target_scaler=target_scaler,
                    eval_config=config.evaluation,
                    fold_num=-1,  # Indicate classic run
                    output_dir=str(classic_output_dir),
                    plot_subdir="plots",
                    prediction_time_index=target_time_index_for_plotting,  # Pass the correctly calculated index
                    forecast_horizons=config.features.forecast_horizon,
                    plot_title_prefix="Classic Run"
                )
                # Save metrics
                save_results({"overall_metrics": test_metrics}, classic_output_dir / "test_metrics.json")
                logger.info(f"Classic run test metrics (overall): {test_metrics}")

                # --- Plot Loss Curve for Classic Run ---
                try:
                    # Adjusted logic to find metrics.csv inside potential version_*/ directories
                    classic_log_dir = classic_output_dir / "training_logs"
                    metrics_file = None
                    version_dirs = list(classic_log_dir.glob("version_*"))
                    if version_dirs:
                        # Assuming the latest version directory contains the relevant logs
                        latest_version_dir = max(version_dirs, key=lambda p: p.stat().st_mtime)
                        potential_metrics_file = latest_version_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file
                        else:
                            logger.warning(f"Classic Run: metrics.csv not found in latest version directory: {latest_version_dir}")
                    else:
                        # Fallback if no version_* directories exist (less common with CSVLogger)
                        potential_metrics_file = classic_log_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file

                    if metrics_file and metrics_file.is_file():
                        logger.info(f"Generating loss curve for classic run from: {metrics_file}")
                        plot_loss_curve_from_csv(
                            metrics_csv_path=metrics_file,
                            output_path=classic_output_dir / "loss_curve.png",
                            title="Classic Run Training Progression",
                            train_loss_col='train_loss',  # Changed from 'train_loss_epoch'
                            val_loss_col='val_loss'  # Keep as 'val_loss'
                        )
                    else:
                        logger.warning(f"Classic Run: Could not find metrics.csv in {classic_log_dir} or its version subdirectories for loss curve plot.")
                except Exception as plot_e:
                    logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
                # --- End Classic Loss Plotting ---

            except Exception as e:
                logger.error(f"Error processing classic prediction results: {e}", exc_info=True)
                test_metrics = None

    except Exception as e:
        logger.error(f"An error occurred during the classic training pipeline: {e}", exc_info=True)
        test_metrics = None  # Indicate failure

    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        run_end_time = time.perf_counter()
        logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---")

    return test_metrics
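
# Usage sketch for run_classic_training. The config path, CSV path, and index column are
# placeholders for this example; in the repository the validated MainConfig and the loaded
# DataFrame come from the main entry point.
#
#   import yaml
#   import pandas as pd
#   from pathlib import Path
#   from forecasting_model.utils.forecast_config_model import MainConfig
#   from forecasting_model.train.classic import run_classic_training
#
#   cfg = MainConfig(**yaml.safe_load(open("forecasting_config.yaml")))
#   df = pd.read_csv("data/timeseries.csv", index_col=0, parse_dates=True)
#   metrics = run_classic_training(cfg, df, output_base_dir=Path("output"))
#   print(metrics)  # e.g. {'MAE': ..., 'RMSE': ...}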
425  forecasting_model/train/ensemble_evaluation.py  Normal file
@ -0,0 +1,425 @@
"""
Ensemble evaluation for time series forecasting models.

This module provides functionality to evaluate ensemble predictions
by combining predictions from n-1 folds and testing on the remaining fold.
"""

import logging
import numpy as np
import torch
import yaml  # For loading fold config
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd  # For time index handling
import pickle  # Needed for the specific unpickling error check

from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig

logger = logging.getLogger(__name__)


def load_fold_model_and_objects(
    fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
    """
    Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.

    Args:
        fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).

    Returns:
        A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
        or None if any essential artifact is missing or loading fails.
    """
    try:
        logger.info(f"Loading artifacts from: {fold_dir}")

        # 1. Load Fold Configuration
        config_path = fold_dir / "config.yaml"
        if not config_path.is_file():
            logger.error(f"Fold config file not found in {fold_dir}")
            return None
        with open(config_path, 'r') as f:
            fold_config_dict = yaml.safe_load(f)
        fold_config = MainConfig(**fold_config_dict)  # Validate fold's config

        # 2. Load Saved Objects using torch.load
        test_loader_path = fold_dir / "test_loader.pt"
        scaler_path = fold_dir / "target_scaler.pt"
        input_size_path = fold_dir / "input_size.pt"
        prediction_index_path = fold_dir / "prediction_target_time_index.pt"

        if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
            logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
            return None

        try:
            # --- Explicitly set weights_only=False for non-model objects ---
            test_loader = torch.load(test_loader_path, weights_only=False)
            target_scaler = torch.load(scaler_path, weights_only=False)
            input_size = torch.load(input_size_path, weights_only=False)
            # --- End Modification ---
        except pickle.UnpicklingError as e:
            # Catch potential unpickling errors even with weights_only=False
            logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
            return None
        except AttributeError as e:
            # Catch potential issues if class definitions changed between saving and loading
            logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
            return None
        except Exception as e:
            # Catch other potential loading errors
            logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
            return None

        # Retrieve forecast horizon list from the fold's config
        forecast_horizons = fold_config.features.forecast_horizon

        # --- Extract prediction target time index (if available) ---
        prediction_target_time_index: Optional[pd.Index] = None
        if prediction_index_path.is_file():
            try:
                prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
                # Basic validation
                if not isinstance(prediction_target_time_index, pd.Index):
                    logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
                    prediction_target_time_index = None
                else:
                    logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
            except Exception as e:
                logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
        else:
            logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
        # --- End Index Extraction ---

        # 3. Find Checkpoint and Load Model
        checkpoint_path = None
        try:
            # Search recursively: the checkpoint may be nested in subdirectories
            checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
            if not checkpoints:
                logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
                return None
            if len(checkpoints) > 1:
                logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
            checkpoint_path = checkpoints[0]

            logger.info(f"Loading model from checkpoint: {checkpoint_path}")
            model = LSTMForecastLightningModule.load_from_checkpoint(
                checkpoint_path,
                map_location=torch.device('cpu'),  # Optional: load to CPU first if memory is tight
                model_config=fold_config.model,
                train_config=fold_config.training,
                input_size=input_size,
                target_scaler=target_scaler
            )
            model.eval()
            logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
            return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons

        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {checkpoint_path}")
            return None
        except Exception as e:
            logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
            return None

    except Exception as e:
        logger.error(f"Generic error loading artifacts from {fold_dir}: {e}", exc_info=True)
        return None

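# The ensemble methods below combine per-model predictions by stacking them and reducing
# across the model axis with NaN-aware aggregations. A minimal, self-contained sketch of
# that combination step (synthetic numbers, not the project's data):
#
#   import numpy as np
#
#   # Three models, four samples, two forecast horizons: shape (num_models, N, H)
#   stacked = np.stack([
#       np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]]),
#       np.array([[1.2, 2.2], [2.2, 3.1], [np.nan, np.nan], [4.1, 5.2]]),  # one failed batch -> NaNs
#       np.array([[0.8, 1.9], [1.9, 2.9], [2.9, 3.8], [3.9, 4.9]]),
#   ], axis=0)
#
#   ensemble = {
#       'mean': np.nanmean(stacked, axis=0),
#       'median': np.nanmedian(stacked, axis=0),
#       'min': np.nanmin(stacked, axis=0),
#       'max': np.nanmax(stacked, axis=0),
#   }
#   # Each value has shape (N, H) and ignores models that produced NaNs for a sample.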
def make_ensemble_predictions(
|
||||||
|
models: List[LSTMForecastLightningModule],
|
||||||
|
test_loader: torch.utils.data.DataLoader,
|
||||||
|
device: Optional[torch.device] = None
|
||||||
|
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Make predictions using an ensemble of models efficiently.
|
||||||
|
|
||||||
|
Processes the test_loader once, getting predictions from all models per batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models: List of trained models (already in eval mode).
|
||||||
|
test_loader: DataLoader for the test set.
|
||||||
|
device: Device to run predictions on (e.g., torch.device("cuda:0")).
|
||||||
|
If None, attempts to use GPU if available, else CPU.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ensemble_predictions, targets):
|
||||||
|
- ensemble_predictions: Dict containing ensemble predictions keyed by method
|
||||||
|
('mean', 'median', 'min', 'max'). Values are np.arrays.
|
||||||
|
Returns None if prediction fails.
|
||||||
|
- targets: Ground truth values as a single np.array. Returns None if prediction fails
|
||||||
|
or targets are unavailable in loader.
|
||||||
|
"""
|
||||||
|
if not models:
|
||||||
|
logger.warning("make_ensemble_predictions received an empty list of models.")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
logger.info(f"Running ensemble predictions on device: {device}")
|
||||||
|
|
||||||
|
# Move all models to the target device
|
||||||
|
for model in models:
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
all_batch_preds: List[List[np.ndarray]] = [[] for _ in models] # Outer list: models, Inner list: batches
|
||||||
|
all_batch_targets: List[np.ndarray] = []
|
||||||
|
targets_available = True
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(test_loader):
|
||||||
|
try:
|
||||||
|
# Determine if batch contains targets
|
||||||
|
if isinstance(batch, (list, tuple)) and len(batch) == 2:
|
||||||
|
x, y = batch
|
||||||
|
x = x.to(device)
|
||||||
|
# Keep targets on CPU until needed for concatenation
|
||||||
|
all_batch_targets.append(y.cpu().numpy())
|
||||||
|
else:
|
||||||
|
x = batch.to(device)
|
||||||
|
targets_available = False # No targets found in this batch
|
||||||
|
|
||||||
|
# Get predictions from all models for this batch
|
||||||
|
for i, model in enumerate(models):
|
||||||
|
try:
|
||||||
|
pred = model(x) # Shape: (batch, horizon)
|
||||||
|
all_batch_preds[i].append(pred.cpu().numpy())
|
||||||
|
except Exception as model_err:
|
||||||
|
logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
|
||||||
|
# Handle error: Fill with NaNs? Skip model? For now, fill with NaNs of expected shape
|
||||||
|
# Infer expected shape: (batch_size, horizon)
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
horizon = models[0].output_size # Assume all models have same horizon
|
||||||
|
nan_preds = np.full((batch_size, horizon), np.nan)
|
||||||
|
all_batch_preds[i].append(nan_preds)
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as batch_err:
|
||||||
|
logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
|
||||||
|
# If a batch fails catastrophically, we might not be able to proceed reliably
|
||||||
|
return None, None # Indicate failure
|
||||||
|
|
||||||
|
# Concatenate batch results for each model
|
||||||
|
model_preds_concat = []
|
||||||
|
for i in range(len(models)):
|
||||||
|
if not all_batch_preds[i]: # Check if any predictions were collected for this model
|
||||||
|
logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
|
||||||
|
continue # Skip this model if it failed on all batches
|
||||||
|
try:
|
||||||
|
model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
|
||||||
|
# Decide how to handle: skip model or fail? Let's skip for robustness.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not model_preds_concat:
|
||||||
|
logger.error("No valid predictions collected from any model in the ensemble.")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Concatenate targets if available
|
||||||
|
targets_concat = None
|
||||||
|
if targets_available and all_batch_targets:
|
||||||
|
try:
|
||||||
|
targets_concat = np.concatenate(all_batch_targets, axis=0)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Failed to concatenate targets: {e}")
|
||||||
|
return None, None # Fail if targets were expected but couldn't be combined
|
||||||
|
elif targets_available and not all_batch_targets:
|
||||||
|
logger.warning("Targets were expected based on first batch, but none were collected.")
|
||||||
|
# Proceed without targets, returning None for them
|
||||||
|
|
||||||
|
# Stack predictions from all models: Shape (num_models, num_samples, horizon)
|
||||||
|
try:
|
||||||
|
stacked_preds = np.stack(model_preds_concat, axis=0)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
|
||||||
|
return None, targets_concat # Return targets if available, but no ensemble preds
|
||||||
|
|
||||||
|
# Calculate different ensemble predictions (handle NaNs potentially introduced by model failures)
|
||||||
|
# np.nanmean, np.nanmedian etc. ignore NaNs
|
||||||
|
ensemble_preds = {
|
||||||
|
'mean': np.nanmean(stacked_preds, axis=0),
|
||||||
|
'median': np.nanmedian(stacked_preds, axis=0),
|
||||||
|
'min': np.nanmin(stacked_preds, axis=0),
|
||||||
|
'max': np.nanmax(stacked_preds, axis=0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
|
||||||
|
return ensemble_preds, targets_concat
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_ensemble_for_test_fold(
|
||||||
|
test_fold_num: int,
|
||||||
|
all_fold_dirs: List[Path],
|
||||||
|
output_base_dir: Path,
|
||||||
|
# full_data_index: Optional[pd.Index] = None # Removed, get from loaded objects
|
||||||
|
) -> Optional[Dict[str, Dict[str, float]]]:
|
||||||
|
"""
|
||||||
|
Evaluates ensemble predictions for a specific test fold.
|
||||||
|
Args:
|
||||||
|
test_fold_num: The 1-based number of the fold to use as the test set.
|
||||||
|
all_fold_dirs: List of paths to all fold directories.
|
||||||
|
output_base_dir: Base directory for saving evaluation results/plots.
|
||||||
|
Returns:
|
||||||
|
        Dictionary containing metrics for each ensemble method for this test fold,
        or None if evaluation fails.
    """
    logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
    test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"

    load_result = load_fold_model_and_objects(test_fold_dir)
    if load_result is None:
        logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
        return None
    # Unpack results including the prediction time index and horizons
    _, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result

    # Load models from all *other* folds
    ensemble_models: List[LSTMForecastLightningModule] = []
    model_forecast_horizons = None  # Track horizons from loaded models
    for i, fold_dir in enumerate(all_fold_dirs):
        current_fold_num = i + 1
        if current_fold_num == test_fold_num:
            continue  # Skip the test fold itself

        model_load_result = load_fold_model_and_objects(fold_dir)
        if model_load_result:
            model, _, _, _, _, _, fold_horizons = model_load_result  # Only need the model here
            if model:
                ensemble_models.append(model)
                # Store horizons from the first successful model load
                if model_forecast_horizons is None:
                    model_forecast_horizons = fold_horizons
                # Optional: Check consistency of horizons across ensemble models
                elif set(model_forecast_horizons) != set(fold_horizons):
                    logger.error(f"Inconsistent forecast horizons between ensemble models! Test fold {test_fold_num} expected {test_forecast_horizons}, "
                                 f"Model {i+1} has {fold_horizons}. Ensemble may be invalid.")
                    # Decide how to handle: error out, or proceed with caution?
                    # return None # Option: Fail hard
        else:
            logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")

    if len(ensemble_models) < 2:
        logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
                       f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
        return {}  # Return empty dict, not None, to indicate process ran but no ensemble formed

    # Check consistency between test fold horizons and ensemble model horizons
    if model_forecast_horizons is None:  # Should not happen if len(ensemble_models) >= 1
        logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
        return None
    if set(test_forecast_horizons) != set(model_forecast_horizons):
        logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
                     f"horizons from ensemble models ({model_forecast_horizons}). Cannot evaluate.")
        return None

    # Make ensemble predictions using the loaded models and the test fold's data loader
    # Use the test fold's config to determine device implicitly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)

    if ensemble_preds_dict is None or targets_np is None:
        logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
        return None  # Indicate failure

    # Evaluate each ensemble method's predictions against the test fold's targets
    fold_ensemble_results: Dict[str, Dict[str, float]] = {}
    for method, preds_np in ensemble_preds_dict.items():
        logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")

        # Define a unique output directory for this method's plots
        method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"

        # Use the prediction_target_time_index loaded earlier
        prediction_time_index_for_plot = None
        if prediction_target_time_index is not None:
            if len(prediction_target_time_index) == targets_np.shape[0]:
                prediction_time_index_for_plot = prediction_target_time_index
            else:
                logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
                               f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")

        # Call the standard evaluation function
        metrics = evaluate_fold_predictions(
            y_true_scaled=targets_np,
            y_pred_scaled=preds_np,
            target_scaler=target_scaler,
            eval_config=test_fold_config.evaluation,
            fold_num=test_fold_num - 1,
            output_dir=str(method_plot_dir.parent.parent),
            plot_subdir=f"method_{method}",
            prediction_time_index=prediction_time_index_for_plot,  # Pass the index
            forecast_horizons=test_forecast_horizons,
            plot_title_prefix=f"Ensemble ({method})"
        )
        fold_ensemble_results[method] = metrics

    logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
    return fold_ensemble_results


def run_ensemble_evaluation(
    config: MainConfig,  # Pass main config for context if needed, though fold configs are loaded
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
    """
    Run ensemble evaluation across all folds, treating each as the test set once.

    Args:
        config: The main configuration object (potentially unused if fold configs sufficient).
        output_base_dir: Base directory where fold outputs are stored.

    Returns:
        Dictionary containing ensemble metrics for each test fold:
        { test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
    """
    logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
    all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}

    # Discover fold directories
    fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
    if not fold_dirs:
        logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
        return {}
    if len(fold_dirs) < 2:
        logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. Skipping.")
        return {}

    logger.info(f"Found {len(fold_dirs)} fold directories.")

    # Iterate through each fold, designating it as the test fold
    for i, test_fold_dir in enumerate(fold_dirs):
        test_fold_num = i + 1  # 1-based fold number
        try:
            results_for_test_fold = evaluate_ensemble_for_test_fold(
                test_fold_num=test_fold_num,
                all_fold_dirs=fold_dirs,
                output_base_dir=output_base_dir,
                # full_data_index=full_data_index # Removed
            )

            if results_for_test_fold is not None:
                # Only add results if the evaluation didn't fail completely
                all_ensemble_results[test_fold_num] = results_for_test_fold

        except Exception as e:
            # Catch unexpected errors during a specific test fold evaluation
            logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True)
            continue  # Continue to the next fold

    # Saving is handled by the main script (`forecasting_model_run.py`) which calls this
    if not all_ensemble_results:
        logger.warning("Ensemble evaluation finished, but no results were generated.")
    else:
        logger.info("===== Finished Cross-Validated Ensemble Evaluation =====")

    return all_ensemble_results
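A minimal, self-contained sketch (not part of this commit) of how the nested results returned by run_ensemble_evaluation could be summarised per ensemble method; the fold numbers, method name "mean" and metric values below are purely illustrative.

# Sketch: average each metric per ensemble method across all test folds.
from typing import Dict
from statistics import mean

def summarize_ensemble_results(results: Dict[int, Dict[str, Dict[str, float]]]) -> Dict[str, Dict[str, float]]:
    """Flatten {fold: {method: {metric: value}}} into per-method metric means."""
    per_method: Dict[str, Dict[str, list]] = {}
    for fold_num, methods in results.items():
        for method, metrics in methods.items():
            for name, value in metrics.items():
                per_method.setdefault(method, {}).setdefault(name, []).append(value)
    return {m: {name: mean(vals) for name, vals in named.items()} for m, named in per_method.items()}

example = {1: {"mean": {"MAE": 1.2, "RMSE": 2.0}}, 2: {"mean": {"MAE": 1.0, "RMSE": 1.8}}}
print(summarize_ensemble_results(example))  # approx {'mean': {'MAE': 1.1, 'RMSE': 1.9}}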
forecasting_model/train/folds.py (Normal file, 0 lines)
@ -9,7 +9,7 @@ from typing import Optional, Dict, Any, Union, List, Tuple
 from sklearn.preprocessing import StandardScaler, MinMaxScaler

 # Assuming config_model is in sibling directory utils/
-from forecasting_model.utils.config_model import ModelConfig, TrainingConfig
+from forecasting_model.utils.forecast_config_model import ModelConfig, TrainingConfig

 logger = logging.getLogger(__name__)

@ -30,41 +30,42 @@ class LSTMForecastLightningModule(pl.LightningModule):
         super().__init__()

         # --- Validate & Store Configs ---
-        # Validate the input_size passed during instantiation
         if input_size <= 0:
             raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")
+        self._input_size = input_size # Use a temporary attribute

-        # Store the validated input_size directly for use in layer definitions
-        self._input_size = input_size # Use a temporary attribute before hparams are saved
+        # Ensure forecast_horizon is a valid list in the config
+        if not hasattr(model_config, 'forecast_horizon') or \
+           not isinstance(model_config.forecast_horizon, list) or \
+           not model_config.forecast_horizon or \
+           any(h <= 0 for h in model_config.forecast_horizon):
+            raise ValueError("ModelConfig requires `forecast_horizon` to be a non-empty list of positive integers.")

-        # Ensure forecast_horizon is set in the config for the output layer
-        if not hasattr(model_config, 'forecast_horizon') or model_config.forecast_horizon is None or model_config.forecast_horizon <= 0:
-            raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
-        self.output_size = model_config.forecast_horizon
+        # Output size is the number of horizons we predict
+        self.output_size = len(model_config.forecast_horizon)
+        # Store the actual horizon list for reference if needed, ensure sorted
+        self.forecast_horizons = sorted(model_config.forecast_horizon)

-        # Store configurations - input_size argument will be saved via save_hyperparameters
         self.model_config = model_config
         self.train_config = train_config
         self.target_scaler = target_scaler # Store scaler for this fold

-        # Use save_hyperparameters() to automatically log configs and allow loading
-        # Pass input_size explicitly to be saved in hparams
-        # Exclude scaler as it's stateful and fold-specific
+        # Use save_hyperparameters() - forecast_horizon is part of model_config which is saved
         self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])
+        # Note: Pydantic models might not be perfectly saved/loaded by PL's hparams, check if needed.
+        # If issues arise loading, might need to flatten relevant hparams manually.

         # --- Define Model Layers ---
-        # Access input_size via hparams now
         self.lstm = nn.LSTM(
             input_size=self.hparams.input_size,
             hidden_size=self.hparams.model_config.hidden_size,
             num_layers=self.hparams.model_config.num_layers,
-            batch_first=True, # Input shape: (batch, seq_len, features)
+            batch_first=True,
             dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
         )
         self.dropout = nn.Dropout(self.hparams.model_config.dropout)

-        # Output layer maps LSTM hidden state to the forecast horizon
-        # We typically take the output of the last time step
+        # Output layer maps LSTM hidden state to the number of forecast horizons
         self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)

         # Optional residual connection handling
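An illustrative sketch (not from the diff) of the shape change this hunk introduces: with forecast_horizon now a list, the linear head emits one value per horizon. The sizes below are made up.

# Sketch: multi-horizon output head, one column per horizon.
import torch
import torch.nn as nn

forecast_horizons = [1, 6, 12]            # assumed example horizons
hidden_size, batch_size = 32, 4
fc = nn.Linear(hidden_size, len(forecast_horizons))

last_time_step_out = torch.randn(batch_size, hidden_size)
predictions = fc(last_time_step_out)
print(predictions.shape)                  # torch.Size([4, 3]) -> (batch, len(horizons))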
@ -96,7 +97,7 @@ class LSTMForecastLightningModule(pl.LightningModule):
         self.val_metrics = metrics.clone(prefix='val_')
         self.test_metrics = metrics.clone(prefix='test_')

-        self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()
+        self.val_MeanAbsoluteError_Original_Scale = torchmetrics.MeanAbsoluteError()


     def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -107,7 +108,8 @@ class LSTMForecastLightningModule(pl.LightningModule):
             x: Input tensor of shape (batch_size, sequence_length, input_size)

         Returns:
-            Predictions tensor of shape (batch_size, forecast_horizon)
+            Predictions tensor of shape (batch_size, len(forecast_horizons))
+            where each element corresponds to a predicted horizon in sorted order.
         """
         # LSTM forward pass
         lstm_out, (hidden, cell) = self.lstm(x) # Shape: (batch, seq_len, hidden_size)
@ -126,59 +128,50 @@ class LSTMForecastLightningModule(pl.LightningModule):
             last_time_step_out = last_time_step_out + residual

         # Final fully connected layer
-        predictions = self.fc(last_time_step_out) # Shape: (batch_size, output_size/horizon)
+        predictions = self.fc(last_time_step_out) # Shape: (batch_size, output_size/len(horizons))

-        return predictions # Shape: (batch_size, forecast_horizon)
+        return predictions # Shape: (batch_size, len(forecast_horizons))

     def _calculate_loss(self, outputs, targets):
-        # Ensure shapes match before loss calculation
+        # Shapes should now be (batch_size, len(horizons)) for both
         if outputs.shape != targets.shape:
-            # Squeeze potential extra dim: (batch, horizon, 1) -> (batch, horizon)
-            if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
-                outputs = outputs.squeeze(-1)
-            if outputs.shape != targets.shape:
-                raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
+            # Minimal check, dataset __getitem__ should ensure this
+            raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
         return self.criterion(outputs, targets)

     def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
-        """Helper to inverse transform data using the stored target scaler."""
+        """Helper to inverse transform data (preds or targets) using the stored target scaler."""
         if self.target_scaler is None:
-            # logger.warning("Cannot inverse transform: target_scaler not available.")
-            return None # Cannot inverse transform
+            return None
+
+        data_cpu = data.detach().cpu().numpy().astype(np.float64)
+        original_shape = data_cpu.shape # e.g., (batch_size, len(horizons))
+        num_elements = data_cpu.size

         # Scaler expects 2D input (N, 1)
-        # Ensure data is on CPU and is float64 for sklearn scaler typically
-        data_cpu = data.detach().cpu().numpy().astype(np.float64)
-        original_shape = data_cpu.shape
-        if data_cpu.ndim == 1:
-            data_flat = data_cpu.reshape(-1, 1)
-        elif data_cpu.ndim == 2: # (batch, horizon)
-            data_flat = data_cpu.reshape(-1, 1)
-        else:
-            logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
-            data_flat = data_cpu.reshape(-1, 1)
+        data_flat = data_cpu.reshape(num_elements, 1)

         try:
             inversed_np = self.target_scaler.inverse_transform(data_flat)
-            # Return as tensor on the original device
+            # Return as tensor on the original device, potentially reshaped
             inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
-            # Reshape back? Or keep flat? Keep flat for direct metric use often.
-            return inversed_tensor.flatten()
-            # return inversed_tensor.reshape(original_shape) # If original shape needed
+            # Reshape back to original multi-horizon shape
+            return inversed_tensor.reshape(original_shape)
+            # return inversed_tensor.flatten() # Keep flat if needed for specific metric inputs
         except Exception as e:
             logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
-            return None # Return None if inverse transform fails
+            return None


     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
-        x, y = batch # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
-        outputs = self(x) # Scaled outputs: (batch, horizon)
+        x, y = batch # Shapes: x=(batch, seq_len, features), y=(batch, len(horizons))
+        outputs = self(x) # Scaled outputs: (batch, len(horizons))
         loss = self._calculate_loss(outputs, y)

         # Log scaled metrics
-        metrics = self.train_metrics(outputs, y) # Update internal state
+        self.train_metrics.update(outputs, y)
         self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True) # Log all metrics in collection
+        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)

         return loss

@ -188,20 +181,22 @@ class LSTMForecastLightningModule(pl.LightningModule):
         loss = self._calculate_loss(outputs, y)

         # Log scaled metrics
-        metrics = self.val_metrics(outputs, y) # Update internal state
+        self.val_metrics.update(outputs, y)
         self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)

-        # Log MAE on ORIGINAL scale if scaler is available (often the primary metric for checkpointing/Optuna)
+        # Log MAE on ORIGINAL scale (primary metric for checkpoints)
         if self.target_scaler is not None:
+            # Inverse transform keeps the (batch, len(horizons)) shape
             outputs_inv = self._inverse_transform(outputs)
             y_inv = self._inverse_transform(y)

             if outputs_inv is not None and y_inv is not None:
-                # Ensure shapes are compatible (flattened by _inverse_transform)
+                # Ensure shapes match
                 if outputs_inv.shape == y_inv.shape:
-                    self.val_mae_original_scale.update(outputs_inv, y_inv)
-                    self.log('val_mae_orig_scale', self.val_mae_original_scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+                    # It will compute the average MAE across all elements if multi-dim
+                    self.val_MeanAbsoluteError_Original_Scale.update(outputs_inv, y_inv)
+                    self.log('val_MeanAbsoluteError_Original_Scale', self.val_MeanAbsoluteError_Original_Scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
                 else:
                     logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
             else:
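A standalone sketch (not from the diff) of the reshape-based inverse transform used above: flatten to (N, 1) for the sklearn scaler, then restore the (batch, n_horizons) shape. The numbers are invented.

# Sketch: inverse-transform a (batch, n_horizons) array via reshape round-trip.
import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.array([[10.0], [20.0], [30.0]]))
scaled = np.array([[0.5, -0.5, 1.0], [0.0, 1.5, -1.0]])   # pretend (batch=2, horizons=3) output
flat = scaled.reshape(-1, 1)                               # (6, 1) as the scaler expects
restored = scaler.inverse_transform(flat).reshape(scaled.shape)
print(restored.shape)                                      # (2, 3), back on the original scale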
@ -5,7 +5,7 @@ This package contains configuration models, helper functions, and other utilities
 """

 # Expose configuration models
-from .config_model import (
+from .forecast_config_model import (
     MainConfig,
     DataConfig,
     FeatureConfig,
@ -44,7 +44,7 @@ class DataConfig(BaseModel):
 class FeatureConfig(BaseModel):
     """Configuration for feature engineering and preprocessing."""
     sequence_length: int = Field(..., gt=0)
-    forecast_horizon: int = Field(..., gt=0)
+    forecast_horizon: List[int] = Field(..., min_length=1, description="List of specific forecast horizons to predict (e.g., [1, 6, 12]).")
     lags: List[int] = []
     rolling_window_sizes: List[int] = []
     use_time_features: bool = True
@ -55,11 +55,11 @@ class FeatureConfig(BaseModel):
     clipping: ClippingConfig = ClippingConfig() # Default instance
     scaling_method: Optional[Literal['standard', 'minmax']] = 'standard' # Added literal validation

-    @field_validator('lags', 'rolling_window_sizes')
+    @field_validator('lags', 'rolling_window_sizes', 'forecast_horizon')
     @classmethod
     def check_positive_list_values(cls, v: List[int]) -> List[int]:
         if any(val <= 0 for val in v):
-            raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
+            raise ValueError('Lists lags, rolling_window_sizes, and forecast_horizon must contain only positive values')
         return v

 class ModelConfig(BaseModel):
@ -69,8 +69,8 @@ class ModelConfig(BaseModel):
     num_layers: int = Field(..., gt=0)
     dropout: float = Field(..., ge=0.0, le=1.0)
     use_residual_skips: bool = False
-    # Add forecast_horizon here to ensure LightningModule gets it directly
-    forecast_horizon: Optional[int] = Field(None, gt=0) # Will be set from FeatureConfig
+    # forecast_horizon: Optional[int] = Field(None, gt=0) # OLD
+    forecast_horizon: Optional[List[int]] = Field(None, min_length=1) # Will be set from FeatureConfig

 class TrainingConfig(BaseModel):
     """Configuration for the training process (PyTorch Lightning)."""
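A minimal pydantic sketch (assumptions only, not the project's full config classes) of the list-valued forecast_horizon field and its positivity check introduced in these hunks; the class name FeatureConfigSketch and the example values are hypothetical.

# Sketch: list-valued forecast_horizon with a pydantic v2 validator.
from typing import List
from pydantic import BaseModel, Field, field_validator

class FeatureConfigSketch(BaseModel):
    sequence_length: int = Field(..., gt=0)
    forecast_horizon: List[int] = Field(..., min_length=1)

    @field_validator('forecast_horizon')
    @classmethod
    def check_positive(cls, v: List[int]) -> List[int]:
        if any(h <= 0 for h in v):
            raise ValueError('forecast_horizon must contain only positive values')
        return v

print(FeatureConfigSketch(sequence_length=72, forecast_horizon=[1, 6, 12]))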
@ -103,26 +103,35 @@ class EvaluationConfig(BaseModel):
 class OptunaConfig(BaseModel):
     """Optional configuration for Optuna hyperparameter optimization."""
     enabled: bool = False
+    study_name: str = "default_study" # Added study_name
     n_trials: int = Field(20, gt=0)
     storage: Optional[str] = None # e.g., "sqlite:///output/hpo_results/study.db"
     direction: Literal['minimize', 'maximize'] = 'minimize'
-    metric_to_optimize: str = 'val_mae_orig_scale'
+    metric_to_optimize: str = 'val_MeanAbsoluteError_Original_Scale' # Updated default metric
     pruning: bool = True

 # --- Top-Level Configuration Model ---

 class MainConfig(BaseModel):
     """Main configuration model nesting all sections."""
     project_name: str = "TimeSeriesForecasting"
-    random_seed: Optional[int] = 42 # Added top-level seed
+    random_seed: Optional[int] = 42
+    log_level: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
+    output_dir: str = Field("output/cv_results", description="Base directory for saving all outputs (results, logs, models, plots).")
+
+    # --- Execution Control ---
+    run_cross_validation: bool = Field(True, description="Run the main cross-validation training loop?")
+    run_classic_training: bool = Field(True, description="Run a single classic train/val/test split training?")
+    run_ensemble_evaluation: bool = Field(True, description="Run ensemble evaluation using CV fold models?")
+    # --- End Execution Control ---

     data: DataConfig
     features: FeatureConfig
-    model: ModelConfig # ModelConfig no longer contains input_size
+    model: ModelConfig
     training: TrainingConfig
     cross_validation: CrossValidationConfig
     evaluation: EvaluationConfig
-    optuna: Optional[OptunaConfig] = OptunaConfig() # Added optional Optuna config
+    optuna: Optional[OptunaConfig] = OptunaConfig()

     @model_validator(mode='after')
     def check_forecast_horizon_consistency(self) -> 'MainConfig':
@ -131,20 +140,33 @@ class MainConfig(BaseModel):
         if self.model.forecast_horizon is None:
             # If model config doesn't have it, set it from features config
             self.model.forecast_horizon = self.features.forecast_horizon
-        elif self.model.forecast_horizon != self.features.forecast_horizon:
+        elif set(self.model.forecast_horizon) != set(self.features.forecast_horizon): # Compare sets for content equality
             # If both are set but differ, raise error
             raise ValueError(
                 f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
                 f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
             )
-        # After potential setting, ensure model.forecast_horizon is actually set
-        if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
-            raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
+        # After potential setting, ensure model.forecast_horizon is actually set and valid
+        if self.model and (
+            self.model.forecast_horizon is None or
+            not isinstance(self.model.forecast_horizon, list) or # Check type
+            len(self.model.forecast_horizon) == 0 or # Check not empty
+            any(h <= 0 for h in self.model.forecast_horizon) # Check positive values
+        ):
+            raise ValueError("ModelConfig requires a non-empty list of positive forecast_horizon values (must be set in features config if not set explicitly in model config).")

         # Input size check is removed as it's not part of static config anymore

         return self

+    @model_validator(mode='after')
+    def check_execution_flags(self) -> 'MainConfig':
+        if not self.run_cross_validation and not self.run_classic_training:
+            raise ValueError("At least one of 'run_cross_validation' or 'run_classic_training' must be True.")
+        if self.run_ensemble_evaluation and not self.run_cross_validation:
+            raise ValueError("'run_ensemble_evaluation' requires 'run_cross_validation' to be True (needs CV fold models).")
+        return self
+
     class Config:
         # Example configuration for Pydantic itself
         validate_assignment = True # Re-validate on assignment
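A plain-Python sketch of the execution-flag rule enforced by the new validator above (the parameter names mirror the config fields; this helper itself is hypothetical and not part of the commit).

# Sketch: the same consistency rule as check_execution_flags, stated standalone.
def check_execution_flags(run_cross_validation: bool,
                          run_classic_training: bool,
                          run_ensemble_evaluation: bool) -> None:
    if not run_cross_validation and not run_classic_training:
        raise ValueError("At least one of 'run_cross_validation' or 'run_classic_training' must be True.")
    if run_ensemble_evaluation and not run_cross_validation:
        raise ValueError("'run_ensemble_evaluation' requires 'run_cross_validation' (it reuses the CV fold models).")

check_execution_flags(True, False, True)    # ok: ensemble rides on the CV fold models
# check_execution_flags(False, True, True)  # would raise: ensemble without CV folds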
forecasting_model/utils/helper.py (Normal file, 173 lines)
@ -0,0 +1,173 @@
import argparse
import json
import logging
import random
from pathlib import Path
from typing import Optional, List, Dict

import numpy as np
import pandas as pd
import torch

import yaml

from forecasting_model import MainConfig

# Get the root logger
logger = logging.getLogger(__name__)


def parse_arguments():
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run the Time Series Forecasting training pipeline using a configuration file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='config.yaml',
        help="Path to the YAML configuration file."
    )
    # Removed seed, debug, and output-dir arguments

    args = parser.parse_args()
    return args


def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file using Pydantic.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Validated MainConfig object.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If the config doesn't match the schema.
    """
    if not config_path.is_file():
        logger.error(f"Configuration file not found at: {config_path}")
        raise FileNotFoundError(f"Config file not found: {config_path}")

    logger.info(f"Loading configuration from: {config_path}")
    try:
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)

        # Validate configuration using Pydantic model
        config = MainConfig(**config_dict)
        logger.info("Configuration loaded and validated successfully.")
        return config
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
        raise
    except Exception as e:  # Catches Pydantic validation errors too
        logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
        raise


def set_seeds(seed: Optional[int] = 42) -> None:
    """
    Set random seeds for reproducibility across libraries.

    Args:
        seed: The seed value to use. If None, uses default 42.
    """
    actual_seed = seed if seed is not None else 42
    if seed is None:
        logger.warning(f"No random_seed specified in config, using default seed: {actual_seed}")
    else:
        logger.info(f"Setting random seed from config: {actual_seed}")

    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    # Ensure reproducibility for CUDA operations where possible
    if torch.cuda.is_available():
        torch.cuda.manual_seed(actual_seed)
        torch.cuda.manual_seed_all(actual_seed)  # For multi-GPU
        # These settings can slow down training but improve reproducibility
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    # PyTorch Lightning seeding (optional, as we seed torch directly)
    # pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility


def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.
    Handles potential NaN values by ignoring them.

    Args:
        all_fold_metrics: A list where each element is a dictionary of
                          metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).

    Returns:
        A dictionary where keys are metric names and values are dicts
        containing 'mean' and 'std' for that metric across folds.
        Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
    """
    if not all_fold_metrics:
        logger.warning("Received empty list for metric aggregation.")
        return {}

    aggregated: Dict[str, Dict[str, float]] = {}
    # Get metric names from the first valid fold's results
    first_valid_metrics = next((m for m in all_fold_metrics if m), None)
    if not first_valid_metrics:
        logger.warning("No valid fold metrics found for aggregation.")
        return {}
    metric_names = list(first_valid_metrics.keys())

    for metric in metric_names:
        # Collect values for this metric across all folds, ignoring NaNs
        values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
        valid_values = [v for v in values if v is not None and not np.isnan(v)]

        if not valid_values:
            logger.warning(f"No valid values found for metric '{metric}' across folds.")
            mean_val = np.nan
            std_val = np.nan
        else:
            mean_val = float(np.mean(valid_values))
            std_val = float(np.std(valid_values))
            logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")

        aggregated[metric] = {'mean': mean_val, 'std': std_val}

    return aggregated


def save_results(results: Dict, filename: Path):
    """Save dictionary results to a JSON file."""
    try:
        filename.parent.mkdir(parents=True, exist_ok=True)
        # Convert numpy types to native Python types for JSON serialization
        results_serializable = json.loads(json.dumps(results, cls=NumpyEncoder))
        with open(filename, 'w') as f:
            json.dump(results_serializable, f, indent=4)
        logger.info(f"Saved results to {filename}")
    except TypeError as e:
        logger.error(f"Serialization error saving results to {filename}. Check for non-serializable types (e.g., numpy types): {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.bool_, bool)):
            return bool(obj)
        elif pd.isna(obj):  # Handle pandas NaT or numpy NaN gracefully
            return None
        return super(NumpyEncoder, self).default(obj)
@ -1,30 +1,34 @@
-import argparse
 import logging
 import sys
-import os
-import random
 from pathlib import Path
 import time
-import json
 import numpy as np
 import pandas as pd
 import torch
 import yaml
 import pytorch_lightning as pl
+from matplotlib import pyplot as plt
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
 from pytorch_lightning.loggers import CSVLogger
+from sklearn.preprocessing import StandardScaler, MinMaxScaler

 # Import necessary components from your project structure
 # Assuming forecasting_model is a package installable or in PYTHONPATH
-from forecasting_model.utils.config_model import MainConfig
+from forecasting_model.utils.forecast_config_model import MainConfig
 from forecasting_model.data_processing import (
     load_raw_data,
     TimeSeriesCrossValidationSplitter,
     prepare_fold_data_and_loaders
 )
-from forecasting_model.model import LSTMForecastLightningModule
+from forecasting_model.train.model import LSTMForecastLightningModule
 from forecasting_model.evaluation import evaluate_fold_predictions
-from typing import Dict, List, Any, Optional
+from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
+
+# Import the new classic training function
+from forecasting_model.train.classic import run_classic_training
+from typing import Dict, List, Optional, Tuple, Union
+from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
+from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot

 # Silence overly verbose libraries if needed
 mpl_logger = logging.getLogger('matplotlib')
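A rough sketch (assumptions marked) of how an entry point can look now that the CLI, config, and seeding helpers live in forecasting_model.utils.helper; the main() wrapper and the hand-off comment are hypothetical, not the script's actual flow.

# Sketch: minimal entry point built on the relocated helpers.
from pathlib import Path
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds

def main() -> None:
    args = parse_arguments()
    config = load_config(Path(args.config))
    set_seeds(config.random_seed)
    # ... hand off to the CV / classic-training / ensemble steps enabled in the config ...

if __name__ == "__main__":
    main()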
@ -33,396 +37,552 @@ pil_logger = logging.getLogger('PIL')
|
|||||||
pil_logger.setLevel(logging.WARNING)
|
pil_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
# --- Basic Logging Setup ---
|
# --- Basic Logging Setup ---
|
||||||
# Configure logging early. Level might be adjusted by config.
|
# Configure logging early. Level might be adjusted by config later.
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(level=logging.INFO,
|
||||||
format='%(asctime)s - %(levelname)-7s - %(message)s',
|
format='%(asctime)s - %(levelname)-7s - %(message)s',
|
||||||
datefmt='%H:%M:%S')
|
datefmt='%H:%M:%S')
|
||||||
# Get the root logger
|
# Get the root logger
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
# --- Argument Parsing ---
|
|
||||||
def parse_arguments():
|
|
||||||
"""Parses command-line arguments."""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run the Time Series Forecasting training pipeline.",
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-c', '--config',
|
|
||||||
type=str,
|
|
||||||
default='config.yaml',
|
|
||||||
help="Path to the YAML configuration file."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--seed',
|
|
||||||
type=int,
|
|
||||||
default=None, # Default to None, use config value if not provided
|
|
||||||
help="Override random seed defined in config."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--debug',
|
|
||||||
action='store_true',
|
|
||||||
help="Override log level to DEBUG."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--output-dir',
|
|
||||||
type=str,
|
|
||||||
default='output/cv_results', # Default output base directory
|
|
||||||
help="Base directory for saving cross-validation results (checkpoints, logs, plots)."
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
# --- Single Fold Processing Function ---
|
||||||
return args
|
def run_single_fold(
|
||||||
|
fold_num: int,
|
||||||
# --- Helper Functions ---
|
train_idx: np.ndarray,
|
||||||
|
val_idx: np.ndarray,
|
||||||
def load_config(config_path: Path) -> MainConfig:
|
test_idx: np.ndarray,
|
||||||
|
config: MainConfig,
|
||||||
|
full_df: pd.DataFrame,
|
||||||
|
output_base_dir: Path # Receives Path object from run_training_pipeline
|
||||||
|
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
|
||||||
"""
|
"""
|
||||||
Load and validate configuration from YAML file using Pydantic.
|
Runs the pipeline for a single cross-validation fold.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_path: Path to the YAML configuration file.
|
fold_num: The zero-based index of the current fold.
|
||||||
|
train_idx: Indices for the training set.
|
||||||
|
val_idx: Indices for the validation set.
|
||||||
|
test_idx: Indices for the test set.
|
||||||
|
config: The main configuration object.
|
||||||
|
full_df: The complete raw DataFrame.
|
||||||
|
output_base_dir: The base directory Path for saving results.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Validated MainConfig object.
|
A tuple containing:
|
||||||
|
- fold_metrics: Dictionary of test metrics for the fold (e.g., {'MAE': ..., 'RMSE': ...}).
|
||||||
Raises:
|
- best_val_score: The best validation score achieved during training (or None).
|
||||||
FileNotFoundError: If the config file doesn't exist.
|
- saved_model_path: Path to the best saved model checkpoint (or None).
|
||||||
yaml.YAMLError: If the file is not valid YAML.
|
- saved_target_scaler_path: Path to the saved target scaler (or None).
|
||||||
pydantic.ValidationError: If the config doesn't match the schema.
|
- saved_input_size_path: Path to the saved input size file (or None).
|
||||||
|
- saved_config_path: Path to the saved config file for this fold (or None).
|
||||||
"""
|
"""
|
||||||
if not config_path.is_file():
|
fold_start_time = time.perf_counter()
|
||||||
logger.error(f"Configuration file not found at: {config_path}")
|
fold_id = fold_num + 1 # User-facing fold number (1-based)
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
|
||||||
|
|
||||||
|
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
|
||||||
|
fold_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.debug(f"Fold output directory: {fold_output_dir}")
|
||||||
|
|
||||||
|
fold_metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Default in case of failure
|
||||||
|
best_val_score: Optional[float] = None
|
||||||
|
best_model_path_str: Optional[str] = None # Use a different name for the string from callback
|
||||||
|
|
||||||
|
# Variables to hold prediction results for plotting later
|
||||||
|
all_preds_scaled: Optional[np.ndarray] = None
|
||||||
|
all_targets_scaled: Optional[np.ndarray] = None
|
||||||
|
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Need to keep scaler reference
|
||||||
|
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
|
||||||
|
|
||||||
|
# Variables to store paths of saved artifacts
|
||||||
|
saved_model_path: Optional[Path] = None
|
||||||
|
saved_target_scaler_path: Optional[Path] = None
|
||||||
|
saved_input_size_path: Optional[Path] = None
|
||||||
|
saved_config_path: Optional[Path] = None
|
||||||
|
|
||||||
logger.info(f"Loading configuration from: {config_path}")
|
|
||||||
try:
|
try:
|
||||||
with open(config_path, 'r') as f:
|
# --- Per-Fold Data Preparation ---
|
||||||
config_dict = yaml.safe_load(f)
|
logger.info("Preparing data loaders for the fold...")
|
||||||
|
# Keep scaler and input_size references returned by prepare_fold_data_and_loaders
|
||||||
|
train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders( # Renamed target_scaler
|
||||||
|
full_df=full_df,
|
||||||
|
train_idx=train_idx,
|
||||||
|
val_idx=val_idx,
|
||||||
|
test_idx=test_idx,
|
||||||
|
target_col=config.data.target_col, # Pass target col name explicitly
|
||||||
|
feature_config=config.features,
|
||||||
|
train_config=config.training,
|
||||||
|
eval_config=config.evaluation
|
||||||
|
)
|
||||||
|
target_scaler = target_scaler_fold # Store the scaler in the outer scope
|
||||||
|
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
|
||||||
|
|
||||||
# Validate configuration using Pydantic model
|
# Save necessary items for potential later use (e.g., ensemble)
|
||||||
config = MainConfig(**config_dict)
|
# Capture the paths when saving
|
||||||
logger.info("Configuration loaded and validated successfully.")
|
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
|
||||||
return config
|
torch.save(target_scaler, saved_target_scaler_path)
|
||||||
except yaml.YAMLError as e:
|
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
|
||||||
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
except Exception as e: # Catches Pydantic validation errors too
|
|
||||||
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def set_seeds(seed: Optional[int] = 42) -> None:
|
# Save input size and capture path
|
||||||
"""
|
saved_input_size_path = fold_output_dir / "input_size.pt"
|
||||||
Set random seeds for reproducibility across libraries.
|
torch.save(input_size, saved_input_size_path)
|
||||||
|
|
||||||
Args:
|
# Save config for this fold (needed for reloading model) and capture path
|
||||||
seed: The seed value to use. If None, uses default 42.
|
config_dump = config.model_dump()
|
||||||
"""
|
saved_config_path = fold_output_dir / "config.yaml" # Capture the path before saving
|
||||||
if seed is None:
|
|
||||||
seed = 42
|
|
||||||
logger.warning(f"No seed provided, using default seed: {seed}")
|
|
||||||
else:
|
|
||||||
logger.info(f"Setting random seed: {seed}")
|
|
||||||
|
|
||||||
random.seed(seed)
|
with open(saved_config_path, 'w') as f:
|
||||||
np.random.seed(seed)
|
yaml.dump(config_dump, f, default_flow_style=False)
|
||||||
torch.manual_seed(seed)
|
|
||||||
# Ensure reproducibility for CUDA operations where possible
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed) # For multi-GPU
|
|
||||||
# These settings can slow down training but improve reproducibility
|
|
||||||
# torch.backends.cudnn.deterministic = True
|
|
||||||
# torch.backends.cudnn.benchmark = False
|
|
||||||
# PyTorch Lightning seeding (optional, as we seed torch directly)
|
|
||||||
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
|
|
||||||
|
|
||||||
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
|
|
||||||
"""
|
|
||||||
Calculate mean and standard deviation of metrics across folds.
|
|
||||||
Handles potential NaN values by ignoring them.
|
|
||||||
|
|
||||||
Args:
|
# --- Model Initialization ---
|
||||||
all_fold_metrics: A list where each element is a dictionary of
|
model = LSTMForecastLightningModule(
|
||||||
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
|
model_config=config.model,
|
||||||
|
train_config=config.training,
|
||||||
|
input_size=input_size,
|
||||||
|
target_scaler=target_scaler_fold # Pass scaler during init
|
||||||
|
)
|
||||||
|
logger.info("LSTMForecastLightningModule initialized.")
|
||||||
|
|
||||||
Returns:
|
# --- PyTorch Lightning Callbacks ---
|
||||||
A dictionary where keys are metric names and values are dicts
|
# Ensure monitor_metric matches the exact name logged in model.py
|
||||||
containing 'mean' and 'std' for that metric across folds.
|
monitor_metric = "val_MeanAbsoluteError_Original_Scale" # Corrected metric name
|
||||||
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
|
monitor_mode = "min"
|
||||||
"""
|
|
||||||
if not all_fold_metrics:
|
|
||||||
logger.warning("Received empty list for metric aggregation.")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
aggregated: Dict[str, Dict[str, float]] = {}
|
early_stop_callback = None
|
||||||
# Get metric names from the first valid fold's results
|
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
|
||||||
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
|
early_stop_callback = EarlyStopping(
|
||||||
if not first_valid_metrics:
|
monitor=monitor_metric,
|
||||||
logger.warning("No valid fold metrics found for aggregation.")
|
min_delta=0.0001,
|
||||||
return {}
|
patience=config.training.early_stopping_patience,
|
||||||
metric_names = list(first_valid_metrics.keys())
|
verbose=True,
|
||||||
|
mode=monitor_mode
|
||||||
|
)
|
||||||
|
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
|
||||||
|
|
||||||
for metric in metric_names:
|
checkpoint_callback = ModelCheckpoint(
|
||||||
# Collect values for this metric across all folds, ignoring NaNs
|
dirpath=fold_output_dir / "checkpoints",
|
||||||
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
|
filename=f"best_model_fold_{fold_id}",
|
||||||
valid_values = [v for v in values if v is not None and not np.isnan(v)]
|
save_top_k=1,
|
||||||
|
monitor=monitor_metric,
|
||||||
|
mode=monitor_mode,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
|
||||||
|
|
||||||
|
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||||
|
|
||||||
|
callbacks = [checkpoint_callback, lr_monitor]
|
||||||
|
if early_stop_callback:
|
||||||
|
callbacks.append(early_stop_callback)
|
||||||
|
|
||||||
|
# --- PyTorch Lightning Logger ---
|
||||||
|
# Log to a subdir specific to the fold, relative to output_base_dir
|
||||||
|
log_dir = output_base_dir / f"fold_{fold_id:02d}" / "training_logs"
|
||||||
|
pl_logger = CSVLogger(save_dir=str(log_dir.parent), name=log_dir.name, version='') # Use name for subdir, version='' to avoid 'version_0'
|
||||||
|
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- PyTorch Lightning Trainer ---
|
||||||
|
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||||
|
devices = 1 if accelerator == 'gpu' else None
|
||||||
|
precision = getattr(config.training, 'precision', 32)
|
||||||
|
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator=accelerator,
|
||||||
|
devices=devices,
|
||||||
|
max_epochs=config.training.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
logger=pl_logger,
|
||||||
|
log_every_n_steps=max(1, len(train_loader)//10),
|
||||||
|
enable_progress_bar=True,
|
||||||
|
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
|
||||||
|
precision=precision,
|
||||||
|
)
|
||||||
|
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
|
||||||
|
|
||||||
|
# --- Training ---
|
||||||
|
logger.info(f"Starting training for Fold {fold_id}...")
|
||||||
|
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
||||||
|
logger.info(f"Training finished for Fold {fold_id}.")
|
||||||
|
|
||||||
|
# Store best validation score and path
|
||||||
|
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
|
||||||
|
# Capture the best model path reported by the checkpoint callback
|
||||||
|
best_model_path_str = trainer.checkpoint_callback.best_model_path # Capture the string path
|
||||||
|
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
|
||||||
|
|
||||||
|
if best_val_score is not None:
|
||||||
|
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {best_val_score:.4f}")
|
||||||
|
# Check if best_model_path was actually set by the callback
|
||||||
|
if best_model_path_str:
|
||||||
|
saved_model_path = Path(best_model_path_str) # Convert string to Path object and store
|
||||||
|
logger.info(f"Best model checkpoint path: {best_model_path_str}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"ModelCheckpoint callback did not report a best_model_path for Fold {fold_id}.")
|
||||||
|
|
||||||
if not valid_values:
|
|
||||||
logger.warning(f"No valid values found for metric '{metric}' across folds.")
|
|
||||||
mean_val = np.nan
|
|
||||||
std_val = np.nan
|
|
||||||
else:
|
else:
|
||||||
mean_val = float(np.mean(valid_values))
|
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
|
||||||
std_val = float(np.std(valid_values))
|
best_model_path_str = None # Ensure string path is None if no best score
|
||||||
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
|
|
||||||
|
|
||||||
aggregated[metric] = {'mean': mean_val, 'std': std_val}
|
|
||||||
|
|
||||||
return aggregated
            logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
            best_model_path_str = None # Ensure string path is None if no best score

        # --- Prediction on Test Set ---
        logger.info(f"Starting prediction for Fold {fold_id} using {'best checkpoint' if saved_model_path else 'last model'}...")

        # Use the best checkpoint path if available, otherwise use the in-memory model instance
        ckpt_path_for_predict = str(saved_model_path) if saved_model_path else None # Use the saved Path object, convert to string for ckpt_path

        prediction_results_list = trainer.predict(
            model=model, # Use the in-memory model instance
            dataloaders=test_loader,
            ckpt_path=ckpt_path_for_predict # Specify checkpoint path if needed, though using model=model is typical
        )

        # --- Process Prediction Results & Get Time Index ---
        if not prediction_results_list:
            logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
            all_preds_scaled = None # Ensure these are None on failure
            all_targets_scaled = None
        else:
            try:
                all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
                n_predictions = len(all_preds_scaled)

                if 'targets_scaled' in prediction_results_list[0]:
                    all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
                    if len(all_targets_scaled) != n_predictions:
                        logger.error(f"Fold {fold_id}: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
                        raise ValueError("Prediction and target count mismatch during evaluation.")
                else:
                    logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate or plot original scale targets.")
                    all_targets_scaled = None

                logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")

                # --- Calculate Correct Time Index for Plotting (First Horizon) ---
                prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
                prediction_target_time_index_h1 = None

                if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
                    try:
                        test_block_index = full_df.index[test_idx]
                        seq_len = config.features.sequence_length
                        first_horizon = config.features.forecast_horizon[0]

                        target_indices_h1 = test_idx + seq_len + first_horizon - 1

                        valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
                        valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]

                        if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
                            prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
                            if len(prediction_target_time_index_h1) != n_predictions:
                                logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
                                               f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
                                prediction_target_time_index_h1 = None
                        else:
                            logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
                            prediction_target_time_index_h1 = None

                        # Save the calculated index if it's valid and evaluation plots are enabled
                        if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
                            try:
                                torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
                                logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
                            except Exception as save_e:
                                logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
                        elif prediction_target_time_index_h1_path.exists():
                            try:
                                prediction_target_time_index_h1_path.unlink()
                                logger.debug("Removed outdated prediction target time index h1 file.")
                            except OSError as e:
                                logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")

                    except Exception as e:
                        logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
                        prediction_target_time_index_h1 = None
                else:
                    logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
                    if prediction_target_time_index_h1_path.exists():
                        try:
                            prediction_target_time_index_h1_path.unlink()
                            logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
                        except OSError as e:
                            logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
                # --- End Index Calculation and Saving ---
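
                # Illustration of the index arithmetic above (hypothetical numbers): with
                # sequence_length=72 and forecast_horizon=[1, 6, 12], a test sample whose input
                # window starts at absolute row i covers rows [i, i+71], so its first-horizon
                # target sits at row i + 72 + 1 - 1 = i + 72 of full_df.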

                # --- Evaluation ---
                if all_targets_scaled is not None: # Only evaluate if targets are available
                    fold_metrics = evaluate_fold_predictions(
                        y_true_scaled=all_targets_scaled, # Pass the (N, H) array
                        y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
                        target_scaler=target_scaler,
                        eval_config=config.evaluation,
                        fold_num=fold_num, # Pass zero-based index
                        output_dir=str(fold_output_dir),
                        plot_subdir="plots",
                        # Pass the calculated index for the targets being plotted (h1 reference)
                        prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
                        forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
                        plot_title_prefix=f"CV Fold {fold_id}"
                    )
                    save_results(fold_metrics, fold_output_dir / "test_metrics.json")
                else:
                    logger.error(f"Skipping evaluation for Fold {fold_id} due to missing targets.")

                # --- Multi-Horizon Plotting ---
                if config.evaluation.save_plots and all_preds_scaled is not None and all_targets_scaled is not None and prediction_target_time_index_h1 is not None and target_scaler is not None:
                    logger.info(f"Generating multi-horizon plot for Fold {fold_id}...")
                    try:
                        multi_horizon_plot_path = fold_output_dir / "plots" / "multi_horizon_forecast.png"
                        # Need to import save_plot function if it's not already imported
                        # from forecasting_model.io.plotting import save_plot # Ensure this import is present if needed
                        fig = create_multi_horizon_time_series_plot(
                            y_true_scaled_all_horizons=all_targets_scaled,
                            y_pred_scaled_all_horizons=all_preds_scaled,
                            target_scaler=target_scaler,
                            prediction_time_index_h1=prediction_target_time_index_h1,
                            forecast_horizons=config.features.forecast_horizon,
                            title=f"Fold {fold_id} Multi-Horizon Forecast",
                            max_points=1000 # Limit points for clarity
                        )
                        # Check if save_plot is available or use fig.savefig()
                        try:
                            save_plot(fig, multi_horizon_plot_path)
                        except NameError:
                            # Fallback if save_plot is not defined/imported
                            fig.savefig(multi_horizon_plot_path)
                            plt.close(fig) # Close the figure after saving
                            logger.warning("Using fig.savefig as save_plot function was not found.")

                    except Exception as plot_e:
                        logger.error(f"Fold {fold_id}: Failed to generate multi-horizon plot: {plot_e}", exc_info=True)
                elif config.evaluation.save_plots:
                    logger.warning(f"Fold {fold_id}: Skipping multi-horizon plot due to missing data (preds, targets, time index, or scaler).")

            except KeyError as e:
                logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
            except ValueError as e: # Catch specific error from above
                logger.error(f"ValueError processing prediction results for Fold {fold_id}: {e}", exc_info=True)
            except Exception as e:
                logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)

        # --- Plot Loss Curve for Fold ---
        try:
            actual_log_dir = Path(pl_logger.log_dir) / pl_logger.name # Should be .../fold_XX/training_logs
            metrics_file_path = actual_log_dir / "metrics.csv"

            if metrics_file_path.is_file():
                plot_loss_curve_from_csv(
                    metrics_csv_path=metrics_file_path,
                    output_path=fold_output_dir / "plots" / "loss_curve.png", # Save in plots subdir
                    title=f"Fold {fold_id} Training Progression",
                    train_loss_col='train_loss',
                    val_loss_col='val_loss' # This function handles fallback
                )
                logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
            else:
                logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
        except Exception as e:
            logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
        # --- End Loss Curve Plotting ---
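
        # For reference, a minimal sketch of what reading a Lightning CSVLogger metrics.csv and
        # plotting the loss columns could look like (assumes pandas/matplotlib; this is not the
        # project's plot_loss_curve_from_csv implementation):
        #
        #     df_metrics = pd.read_csv(metrics_file_path)
        #     fig, ax = plt.subplots(figsize=(8, 4))
        #     for col in ('train_loss', 'val_loss'):
        #         if col in df_metrics.columns:
        #             series = df_metrics[['epoch', col]].dropna().groupby('epoch').mean()
        #             ax.plot(series.index, series[col], label=col)
        #     ax.set_xlabel('epoch'); ax.set_ylabel('loss'); ax.legend()
        #     fig.savefig(fold_output_dir / "plots" / "loss_curve_sketch.png"); plt.close(fig)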
    except Exception as e:
        logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
        # Ensure paths are None if an error occurs before they are set
        if saved_model_path is None: saved_model_path = None
        if saved_target_scaler_path is None: saved_target_scaler_path = None

    finally:
        # Clean up GPU memory explicitly
        del model, trainer # Ensure objects are deleted before clearing cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("Cleared CUDA cache.")

        # Delete loaders explicitly if they might hold references
        del train_loader, val_loader, test_loader

        fold_end_time = time.perf_counter()
        logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")

    # Return the calculated fold metrics, best validation score, and saved artifact paths
    return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path


def save_results(results: Dict, filename: Path):
    """Save dictionary results to a JSON file."""
    try:
        filename.parent.mkdir(parents=True, exist_ok=True)
        with open(filename, 'w') as f:
            json.dump(results, f, indent=4)
        logger.info(f"Saved results to {filename}")
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
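
# Note: json.dump above assumes plain Python numbers in `results`. If numpy scalars ever end
# up in the metrics dict, one hedged workaround (not part of the original function) is a
# default converter, e.g.:
#
#     json.dump(results, f, indent=4,
#               default=lambda o: float(o) if isinstance(o, (np.floating, np.integer)) else str(o))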

# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
    """Runs the full cross-validation training and evaluation pipeline."""
    """Runs the full training and evaluation pipeline based on config flags."""
    start_time = time.perf_counter()
    logger.info(f"Starting training pipeline. Results will be saved to: {output_base_dir}")
    output_base_dir.mkdir(parents=True, exist_ok=True)

    # --- Data Loading ---
    try:
        df = load_raw_data(config.data)
    except Exception as e:
        logger.critical(f"Failed to load raw data: {e}", exc_info=True)
        sys.exit(1) # Cannot proceed without data
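
    # A rough sketch of the kind of expanding-window split the TimeSeriesCrossValidationSplitter
    # used below is expected to yield (illustrative only, using sklearn as a stand-in; the
    # project's splitter and its configuration are defined elsewhere):
    #
    #     from sklearn.model_selection import TimeSeriesSplit
    #     tscv = TimeSeriesSplit(n_splits=config.cross_validation.n_splits)
    #     for train_block, eval_block in tscv.split(np.arange(len(df))):
    #         mid = len(eval_block) // 2
    #         train_idx, val_idx, test_idx = train_block, eval_block[:mid], eval_block[mid:]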

    # --- Initialize results ---
    all_fold_test_metrics: List[Dict[str, float]] = []
    all_fold_best_val_scores: Dict[int, Optional[float]] = {} # Store best val score per fold
    aggregated_metrics: Dict = {}
    final_results: Dict = {} # Initialize empty results dict

    # --- Cross-Validation Loop ---
    if config.run_cross_validation:
        logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
        try:
            cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
        except ValueError as e:
            logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
            sys.exit(1)

        for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
            # Unpack the two new return values from run_single_fold
            fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
                fold_num=fold_num,
                train_idx=train_idx,
                val_idx=val_idx,
                test_idx=test_idx,
                config=config,
                full_df=df,
                output_base_dir=output_base_dir
            )

    # --- Previous inline cross-validation loop (removed by this commit; the per-fold logic now lives in run_single_fold above) ---
    # --- Cross-Validation Setup ---
    try:
        cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
    except ValueError as e:
        logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
        sys.exit(1)

    logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
        fold_start_time = time.perf_counter()
        fold_id = fold_num + 1
        logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")

        fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
        fold_output_dir.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Fold output directory: {fold_output_dir}")

        try:
            # --- Per-Fold Data Preparation ---
            logger.info("Preparing data loaders for the fold...")
            train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
                full_df=df,
                train_idx=train_idx,
                val_idx=val_idx,
                test_idx=test_idx,
                target_col=config.data.target_col, # Pass target col name explicitly
                feature_config=config.features,
                train_config=config.training,
                eval_config=config.evaluation
            )
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
|
|
||||||
|
|
||||||
# --- Model Initialization ---
|
|
||||||
# Pass input_size directly, ModelConfig no longer holds it.
|
|
||||||
# Ensure forecast horizon is consistent (checked in MainConfig validation)
|
|
||||||
current_model_config = config.model # Use the validated model config
|
|
||||||
|
|
||||||
model = LSTMForecastLightningModule(
|
|
||||||
model_config=current_model_config, # Does not contain input_size
|
|
||||||
train_config=config.training,
|
|
||||||
input_size=input_size, # Pass the dynamically determined input_size
|
|
||||||
target_scaler=target_scaler # Pass the fold-specific scaler
|
|
||||||
)
|
|
||||||
logger.info("LSTMForecastLightningModule initialized.")
|
|
||||||
|
|
||||||
# --- PyTorch Lightning Callbacks ---
|
|
||||||
# Monitor the validation MAE on the original scale (logged by LightningModule)
|
|
||||||
monitor_metric = "val_mae_orig_scale"
|
|
||||||
monitor_mode = "min"
|
|
||||||
|
|
||||||
early_stop_callback = None
|
|
||||||
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
|
|
||||||
early_stop_callback = EarlyStopping(
|
|
||||||
monitor=monitor_metric,
|
|
||||||
min_delta=0.0001, # Minimum change to qualify as improvement
|
|
||||||
patience=config.training.early_stopping_patience,
|
|
||||||
verbose=True,
|
|
||||||
mode=monitor_mode
|
|
||||||
)
|
|
||||||
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
|
|
||||||
|
|
||||||
# Checkpoint callback to save the best model based on validation metric
|
|
||||||
checkpoint_callback = ModelCheckpoint(
|
|
||||||
dirpath=fold_output_dir / "checkpoints",
|
|
||||||
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
|
|
||||||
save_top_k=1,
|
|
||||||
monitor=monitor_metric,
|
|
||||||
mode=monitor_mode,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
|
|
||||||
|
|
||||||
# Learning rate monitor callback
|
|
||||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
|
||||||
|
|
||||||
callbacks = [checkpoint_callback, lr_monitor]
|
|
||||||
if early_stop_callback:
|
|
||||||
callbacks.append(early_stop_callback)
|
|
||||||
|
|
||||||
# --- PyTorch Lightning Logger ---
|
|
||||||
# Log metrics to a CSV file within the fold directory
|
|
||||||
pl_logger = CSVLogger(save_dir=str(output_base_dir), name=f"fold_{fold_id:02d}", version='logs')
|
|
||||||
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
|
|
||||||
|
|
||||||
# --- PyTorch Lightning Trainer ---
|
|
||||||
# Determine accelerator and devices based on PyTorch check
|
|
||||||
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
|
|
||||||
devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs [0], [1] etc.
|
|
||||||
precision = getattr(config.training, 'precision', 32) # Default to 32-bit
|
|
||||||
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator=accelerator,
|
|
||||||
devices=devices,
|
|
||||||
max_epochs=config.training.epochs,
|
|
||||||
callbacks=callbacks,
|
|
||||||
logger=pl_logger,
|
|
||||||
log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
|
|
||||||
enable_progress_bar=True, # Set to False for less verbose runs (e.g., HPO)
|
|
||||||
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
|
|
||||||
precision=precision,
|
|
||||||
# deterministic=True, # For stricter reproducibility (can slow down)
|
|
||||||
)
|
|
||||||
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
|
|
||||||
|
|
||||||
# --- Training ---
|
|
||||||
logger.info(f"Starting training for Fold {fold_id}...")
|
|
||||||
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
|
||||||
logger.info(f"Training finished for Fold {fold_id}.")
|
|
||||||
|
|
||||||
# Store best validation score for this fold
|
|
||||||
best_val_score = trainer.checkpoint_callback.best_model_score
|
|
||||||
best_model_path = trainer.checkpoint_callback.best_model_path
|
|
||||||
all_fold_best_val_scores[fold_id] = best_val_score.item() if best_val_score else None
|
|
||||||
if best_val_score is not None:
|
|
||||||
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
|
|
||||||
logger.info(f"Best model checkpoint path: {best_model_path}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
|
|
||||||
best_model_path = None # Ensure evaluation doesn't try to load 'best' if checkpointing failed
|
|
||||||
|
|
||||||
# --- Prediction on Test Set ---
|
|
||||||
# Use trainer.predict() to get model outputs
|
|
||||||
logger.info(f"Starting prediction for Fold {fold_id} using best checkpoint...")
|
|
||||||
# predict_step returns dict {'preds_scaled': ..., 'targets_scaled': ...}
|
|
||||||
# We pass the test_loader here, which yields (x, y) pairs, so predict_step will include targets
|
|
||||||
prediction_results_list = trainer.predict(
|
|
||||||
# model=model, # Not needed if using ckpt_path
|
|
||||||
ckpt_path=best_model_path if best_model_path else 'last', # Load best model or last if best failed
|
|
||||||
dataloaders=test_loader
|
|
||||||
# return_predictions=True # Default is True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if prediction returned results
|
|
||||||
if not prediction_results_list:
|
|
||||||
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
|
|
||||||
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# Concatenate predictions and targets from predict_step results
|
|
||||||
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
|
|
||||||
# Check if targets were included (they should be if using test_loader)
|
|
||||||
if 'targets_scaled' in prediction_results_list[0]:
|
|
||||||
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
|
|
||||||
else:
|
|
||||||
# This case shouldn't happen if using test_loader, but good safeguard
|
|
||||||
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
|
|
||||||
raise ValueError("Targets missing from prediction results.")
|
|
||||||
|
|
||||||
|
|
||||||
# --- Final Evaluation & Plotting ---
|
|
||||||
logger.info(f"Processing prediction results for Fold {fold_id}...")
|
|
||||||
fold_metrics = evaluate_fold_predictions(
|
|
||||||
y_true_scaled=all_targets_scaled,
|
|
||||||
y_pred_scaled=all_preds_scaled,
|
|
||||||
target_scaler=target_scaler, # Use the scaler from this fold
|
|
||||||
eval_config=config.evaluation,
|
|
||||||
fold_num=fold_num, # Pass zero-based index
|
|
||||||
output_dir=output_base_dir, # Base dir for saving plots etc.
|
|
||||||
# time_index=df.iloc[test_idx].index # Pass time index if needed
|
|
||||||
)
|
|
||||||
# Save fold metrics
|
|
||||||
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
|
|
||||||
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
|
|
||||||
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
|
|
||||||
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
|
||||||
|
|
||||||
all_fold_test_metrics.append(fold_metrics)
|
all_fold_test_metrics.append(fold_metrics)
|
||||||
|
all_fold_best_val_scores[fold_num + 1] = best_val_score
|
||||||
|
|
||||||
# --- (Optional) Log final test metrics using trainer.test() ---
|
# --- Aggregation and Reporting for CV ---
|
||||||
# If you want the metrics logged by test_step aggregated, call test now.
|
logger.info("Cross-validation finished. Aggregating results...")
|
||||||
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
|
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
|
||||||
# try:
|
final_results['aggregated_test_metrics'] = aggregated_metrics
|
||||||
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
|
final_results['per_fold_test_metrics'] = all_fold_test_metrics
|
||||||
# except Exception as e:
|
final_results['per_fold_best_val_scores'] = all_fold_best_val_scores
|
||||||
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
|
# Save intermediate results after CV
|
||||||
|
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Catch errors during the fold processing (data prep, training, prediction, eval)
|
|
||||||
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
|
|
||||||
all_fold_test_metrics.append({'MAE': np.nan, 'RMSE': np.nan})
|
|
||||||
|
|
||||||
|
|
||||||
# --- Cleanup per fold ---
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
logger.debug("Cleared CUDA cache.")
|
|
||||||
|
|
||||||
fold_end_time = time.perf_counter()
|
|
||||||
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
|
|
||||||
|
|
||||||
|
|
||||||
# --- Aggregation and Final Reporting ---
|
|
||||||
logger.info("Cross-validation finished. Aggregating results...")
|
|
||||||
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
|
|
||||||
|
|
||||||
# Save aggregated results
|
|
||||||
final_results = {
|
|
||||||
'aggregated_test_metrics': aggregated_metrics,
|
|
||||||
'per_fold_test_metrics': all_fold_test_metrics,
|
|
||||||
'per_fold_best_val_scores': all_fold_best_val_scores,
|
|
||||||
}
|
|
||||||
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
|
|
||||||
|
|
||||||
|
|
||||||
# Log final results
|
|
||||||
logger.info("--- Aggregated Cross-Validation Test Results ---")
|
|
||||||
if aggregated_metrics:
|
|
||||||
for metric, stats in aggregated_metrics.items():
|
|
||||||
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
|
|
||||||
else:
|
else:
|
||||||
logger.warning("No metrics available for aggregation.")
|
logger.info("Skipping Cross-Validation loop as per config.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Ensemble Evaluation ---
|
||||||
|
if config.run_ensemble_evaluation:
|
||||||
|
# The validator in MainConfig already ensures run_cross_validation is also true here
|
||||||
|
logger.info("Starting ensemble evaluation...")
|
||||||
|
try:
|
||||||
|
ensemble_results = run_ensemble_evaluation(
|
||||||
|
config=config, # Pass config for context if needed by sub-functions
|
||||||
|
output_base_dir=output_base_dir
|
||||||
|
)
|
||||||
|
if ensemble_results:
|
||||||
|
logger.info("Ensemble evaluation completed successfully")
|
||||||
|
final_results['ensemble_results'] = ensemble_results
|
||||||
|
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
|
||||||
|
else:
|
||||||
|
logger.warning("No ensemble results were generated (potentially < 2 folds available).")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during ensemble evaluation: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logger.info("Skipping Ensemble evaluation as per config.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Classic Training Run ---
|
||||||
|
if config.run_classic_training:
|
||||||
|
logger.info("Starting classic training run...")
|
||||||
|
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
|
||||||
|
try:
|
||||||
|
# Call the original classic training function directly
|
||||||
|
classic_metrics = run_classic_training(
|
||||||
|
config=config,
|
||||||
|
full_df=df,
|
||||||
|
output_base_dir=output_base_dir # It creates classic_run subdir internally
|
||||||
|
)
|
||||||
|
if classic_metrics:
|
||||||
|
logger.info(f"Classic training run completed. Test Metrics: {classic_metrics}")
|
||||||
|
final_results['classic_training_results'] = classic_metrics
|
||||||
|
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
|
||||||
|
|
||||||
|
# --- Plot Loss Curve for Classic Run ---
|
||||||
|
try:
|
||||||
|
classic_log_dir = classic_output_dir / "training_logs"
|
||||||
|
metrics_file = classic_log_dir / "metrics.csv"
|
||||||
|
version_dirs = list(classic_log_dir.glob("version_*"))
|
||||||
|
if version_dirs:
|
||||||
|
metrics_file = version_dirs[0] / "metrics.csv"
|
||||||
|
|
||||||
|
if metrics_file.is_file():
|
||||||
|
plot_loss_curve_from_csv(
|
||||||
|
metrics_csv_path=metrics_file,
|
||||||
|
output_path=classic_output_dir / "loss_curve.png",
|
||||||
|
title="Classic Run Training Progression",
|
||||||
|
train_loss_col='train_loss', # Changed from 'train_loss_epoch'
|
||||||
|
val_loss_col='val_loss' # Check your logged metric names
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Classic Run: Could not find metrics.csv at {metrics_file} for loss curve plot.")
|
||||||
|
except Exception as plot_e:
|
||||||
|
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
|
||||||
|
# --- End Classic Loss Plotting ---
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning("Classic training run did not produce metrics.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during classic training run: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logger.info("Skipping Classic training run as per config.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Final Logging Summary ---
|
||||||
|
logger.info("--- Final Summary ---")
|
||||||
|
# Log aggregated CV results if they exist
|
||||||
|
if 'aggregated_test_metrics' in final_results and final_results['aggregated_test_metrics']:
|
||||||
|
logger.info("--- Aggregated Cross-Validation Test Results ---")
|
||||||
|
for metric, stats in final_results['aggregated_test_metrics'].items():
|
||||||
|
logger.info(f"{metric}: {stats.get('mean', np.nan):.4f} ± {stats.get('std', np.nan):.4f}")
|
||||||
|
elif config.run_cross_validation:
|
||||||
|
logger.warning("Cross-validation was run, but no metrics were aggregated.")
|
||||||
|
|
||||||
|
# Log aggregated ensemble results if they exist
|
||||||
|
if 'ensemble_results' in final_results and final_results['ensemble_results']:
|
||||||
|
logger.info("--- Aggregated Ensemble Test Results (Mean over Test Folds) ---")
|
||||||
|
agg_ensemble = {}
|
||||||
|
for fold_res in final_results['ensemble_results'].values():
|
||||||
|
if isinstance(fold_res, dict):
|
||||||
|
for method, metrics in fold_res.items():
|
||||||
|
if method not in agg_ensemble: agg_ensemble[method] = {}
|
||||||
|
if isinstance(metrics, dict):
|
||||||
|
for m_name, m_val in metrics.items():
|
||||||
|
if m_name not in agg_ensemble[method]: agg_ensemble[method][m_name] = []
|
||||||
|
agg_ensemble[method][m_name].append(m_val)
|
||||||
|
else: logger.warning(f"Skipping non-dict metrics for ensemble method '{method}'.")
|
||||||
|
else: logger.warning("Skipping non-dict fold result in ensemble aggregation.")
|
||||||
|
|
||||||
|
for method, metrics_data in agg_ensemble.items():
|
||||||
|
logger.info(f" Ensemble Method: {method}")
|
||||||
|
for m_name, values in metrics_data.items():
|
||||||
|
valid_vals = [v for v in values if v is not None and not np.isnan(v)]
|
||||||
|
if valid_vals: logger.info(f" {m_name}: {np.mean(valid_vals):.4f} ± {np.std(valid_vals):.4f}")
|
||||||
|
else: logger.info(f" {m_name}: N/A")
|
||||||
|
|
||||||
|
|
||||||
|
# Log classic results if they exist
|
||||||
|
if 'classic_training_results' in final_results and final_results['classic_training_results']:
|
||||||
|
logger.info("--- Classic Training Test Results ---")
|
||||||
|
classic_res = final_results['classic_training_results']
|
||||||
|
for metric, value in classic_res.items():
|
||||||
|
logger.info(f"{metric}: {value:.4f}")
|
||||||
|
|
||||||
logger.info("-------------------------------------------------")
|
logger.info("-------------------------------------------------")
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
@ -434,12 +594,6 @@ def run():
    """Main execution function."""
    args = parse_arguments()
    config_path = Path(args.config)
    output_dir = Path(args.output_dir)

    # Adjust log level if debug flag is set
    if args.debug:
        logger.setLevel(logging.DEBUG)
        logger.debug("# --- Debug mode enabled. --- #")

    # --- Configuration Loading ---
    try:
@ -448,10 +602,20 @@ def run():
        # Error already logged in load_config
        sys.exit(1)

    # --- Seed Setting ---
    # Use command-line seed if provided, otherwise use config seed
    seed = args.seed if args.seed is not None else getattr(config, 'random_seed', 42)
    set_seeds(seed)
    # --- Setup based on Config ---
    # 1. Set Log Level
    log_level_name = config.log_level.upper()
    log_level = getattr(logging, log_level_name, logging.INFO)
    logger.setLevel(log_level)
    logger.info(f"Log level set to: {log_level_name}")
    if log_level == logging.DEBUG:
        logger.debug("# --- Debug mode enabled via config. --- #")

    # 2. Set Seed
    set_seeds(config.random_seed)

    # 3. Determine Output Directory
    output_dir = Path(config.output_dir)

    # --- Pipeline Execution ---
    try:
@ -459,7 +623,7 @@ def run():
    except SystemExit as e:
        logger.warning(f"Pipeline exited with code {e.code}.")
        sys.exit(e.code) # Propagate exit code
    except Exception as e:
        logger.critical(f"A critical error occurred during pipeline execution: {e}", exc_info=True)
        sys.exit(1)
main.py (deleted, 123 lines)
@ -1,123 +0,0 @@
import logging
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Any
|
|
||||||
|
|
||||||
from forecasting_model.utils.config_model import MainConfig
|
|
||||||
from forecasting_model.data_processing import (
|
|
||||||
load_raw_data,
|
|
||||||
TimeSeriesCrossValidationSplitter,
|
|
||||||
prepare_fold_data_and_loaders
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def load_config(config_path: Path) -> MainConfig:
|
|
||||||
"""
|
|
||||||
Load and validate configuration from YAML file.
|
|
||||||
"""
|
|
||||||
# TODO: Implement config loading
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_seeds(seed: int = 42) -> None:
|
|
||||||
"""
|
|
||||||
Set random seeds for reproducibility.
|
|
||||||
"""
|
|
||||||
# TODO: Implement seed setting
|
|
||||||
pass
|
|
||||||
|
|
||||||
def determine_device(config: MainConfig) -> torch.device:
|
|
||||||
"""
|
|
||||||
Determine the device to use for training.
|
|
||||||
"""
|
|
||||||
# TODO: Implement device determination
|
|
||||||
pass
|
|
||||||
|
|
||||||
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
|
|
||||||
"""
|
|
||||||
Calculate mean and standard deviation of metrics across folds.
|
|
||||||
"""
|
|
||||||
# TODO: Implement metric aggregation
|
|
||||||
pass
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Load configuration
|
|
||||||
config = load_config(Path("config.yaml"))
|
|
||||||
|
|
||||||
# Set random seeds
|
|
||||||
set_seeds()
|
|
||||||
|
|
||||||
# Determine device
|
|
||||||
device = determine_device(config)
|
|
||||||
|
|
||||||
# Load raw data
|
|
||||||
df = load_raw_data(config.data)
|
|
||||||
|
|
||||||
# Initialize CV splitter
|
|
||||||
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
|
|
||||||
|
|
||||||
# Initialize list to store fold metrics
|
|
||||||
all_fold_metrics = []
|
|
||||||
|
|
||||||
# Cross-validation loop
|
|
||||||
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split(), 1):
|
|
||||||
logger.info(f"Starting fold {fold_num}")
|
|
||||||
|
|
||||||
# Prepare data loaders
|
|
||||||
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
|
|
||||||
df, train_idx, val_idx, test_idx,
|
|
||||||
config.features, config.training, config.evaluation
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update model config with input size
|
|
||||||
config.model.input_size = input_size
|
|
||||||
|
|
||||||
# Initialize model
|
|
||||||
model = LSTMForecastModel(config.model).to(device)
|
|
||||||
|
|
||||||
# Initialize loss function
|
|
||||||
loss_fn = torch.nn.MSELoss() if config.training.loss_function == "MSE" else torch.nn.L1Loss()
|
|
||||||
|
|
||||||
# Initialize scheduler if configured
|
|
||||||
scheduler = None
|
|
||||||
if config.training.scheduler_step_size is not None:
|
|
||||||
# TODO: Initialize scheduler
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Initialize trainer
|
|
||||||
trainer = Trainer(
|
|
||||||
model, train_loader, val_loader, loss_fn, device,
|
|
||||||
config.training, scheduler, target_scaler
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train model
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
# Evaluate on test set
|
|
||||||
fold_metrics = evaluate_fold(
|
|
||||||
model, test_loader, loss_fn, device,
|
|
||||||
target_scaler, config.evaluation, fold_num
|
|
||||||
)
|
|
||||||
|
|
||||||
all_fold_metrics.append(fold_metrics)
|
|
||||||
|
|
||||||
# Optional: Clear GPU memory
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Aggregate metrics
|
|
||||||
aggregated_metrics = aggregate_cv_metrics(all_fold_metrics)
|
|
||||||
|
|
||||||
# Log final results
|
|
||||||
logger.info("Cross-validation results:")
|
|
||||||
for metric, stats in aggregated_metrics.items():
|
|
||||||
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
optim_config.yaml (new file, 25 lines)
@ -0,0 +1,25 @@
# Configuration for the battery optimization runs

# Initial state of charge of the battery (MWh)
initial_b: 0.0

# Maximum energy capacity of the battery (MWh)
max_capacity: 1.0

# Maximum charge/discharge power rate of the battery (MW)
max_rate: 1.0

# The length of the time window (in hours) for which the optimization is run
# This should match the forecast horizon of the models being evaluated.
optimization_horizon_hours: 24

# List of models to evaluate. Each entry includes the path to the model's
# forecast output file and the path to the forecasting config used for that model.
models:
  - name: "Model_A"
    forecast_path: "path/to/model_a_forecast_output.csv"             # Path to the file containing forecast time points and prices
    forecast_config_path: "configs/model_a_forecasting_config.yaml"  # Path to the forecasting config used for this model
  - name: "Model_B"
    forecast_path: "path/to/model_b_forecast_output.csv"
    forecast_config_path: "configs/model_b_forecasting_config.yaml"
  # Add more models here
optim_run.py (new file, 544 lines)
@ -0,0 +1,544 @@
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Import Forecasting Providers
|
||||||
|
from forecasting_model.data_processing import load_raw_data
|
||||||
|
from optimizer.forecasting.base import ForecastProvider
|
||||||
|
from optimizer.forecasting.single_model import SingleModelProvider
|
||||||
|
from optimizer.forecasting.ensemble import EnsembleProvider
|
||||||
|
|
||||||
|
from optimizer.optimization.battery import solve_battery_optimization_hourly
|
||||||
|
from optimizer.utils.optim_config import OptimizationRunConfig
|
||||||
|
from forecasting_model.utils.forecast_config_model import DataConfig, MainConfig
|
||||||
|
|
||||||
|
# Import the newly created loading functions
|
||||||
|
from optimizer.utils.model_io import load_single_model_artifact, load_ensemble_artifact
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional, Union # Added Union
|
||||||
|
|
||||||
|
# Silence overly verbose libraries if needed
|
||||||
|
mpl_logger = logging.getLogger('matplotlib')
|
||||||
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
|
pil_logger = logging.getLogger('PIL')
|
||||||
|
pil_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# --- Basic Logging Setup ---
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)-7s - %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
def load_optimization_config(config_path: str) -> OptimizationRunConfig | None:
|
||||||
|
"""Loads the main optimization configuration from a YAML file."""
|
||||||
|
logger.info(f"Loading optimization config from {config_path}")
|
||||||
|
try:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config_data = yaml.safe_load(f)
|
||||||
|
return OptimizationRunConfig(**config_data)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"Optimization config file not found at {config_path}")
|
||||||
|
return None
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.error(f"Error parsing YAML optimization config file: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading optimization config: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_main_forecasting_config(config_path: str) -> MainConfig | None:
|
||||||
|
"""Loads the main forecasting configuration from a YAML file."""
|
||||||
|
logger.info(f"Loading main forecasting config from: {config_path}")
|
||||||
|
try:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config_data = yaml.safe_load(f)
|
||||||
|
# Assuming MainConfig is the top-level model in forecast_config_model.py
|
||||||
|
return MainConfig(**config_data)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"Main forecasting config file not found at {config_path}")
|
||||||
|
return None
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.error(f"Error parsing YAML main forecasting config file: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading main forecasting config: {e}", exc_info=True)
|
||||||
|
return None
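
# For orientation only: a minimal sketch of the kind of hourly battery LP that
# solve_battery_optimization_hourly (imported above from optimizer.optimization.battery)
# is called with later in this script. The sketch assumes cvxpy and mirrors only the call
# signature used below (prices, initial_b, max_capacity, max_rate); it is not the project's
# implementation.
def solve_battery_lp_sketch(prices, initial_b, max_capacity, max_rate):
    import cvxpy as cp  # assumed dependency for this sketch only
    T = len(prices)
    power = cp.Variable(T)       # charge (+) / discharge (-) in MW, one value per hour
    soc = cp.Variable(T + 1)     # state of charge in MWh, including the initial state
    constraints = [
        soc[0] == initial_b,
        soc[1:] == soc[:-1] + power,   # 1 h steps, so MW * 1 h == MWh
        soc >= 0,
        soc <= max_capacity,
        cp.abs(power) <= max_rate,
    ]
    # Buying at `price` while charging costs money, discharging earns it.
    problem = cp.Problem(cp.Maximize(-(np.asarray(prices) @ power)), constraints)
    problem.solve()
    return problem.status, problem.value, power.value, soc.value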
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution Logic ---
|
||||||
|
# 1. Load configs
|
||||||
|
# 2. Initialize forecast providers
|
||||||
|
# 3. For each time window:
|
||||||
|
# a. Get forecasts for all horizons
|
||||||
|
# b. Run optimization for each horizon
|
||||||
|
# c. Store results
|
||||||
|
# 4. Evaluate and visualize
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("Starting battery optimization evaluation with baseline and models/ensembles.")
|
||||||
|
|
||||||
|
# --- Load Main Optimization Config ---
|
||||||
|
optimization_config_path = "optim_config.yaml"
|
||||||
|
optimization_config = load_optimization_config(optimization_config_path)
|
||||||
|
|
||||||
|
if optimization_config is None:
|
||||||
|
logger.critical("Failed to load main optimization config. Exiting.") # Use critical for exit
|
||||||
|
exit(1) # Use non-zero exit code for error
|
||||||
|
|
||||||
|
optim_run_script_dir = Path(__file__).parent
|
||||||
|
|
||||||
|
if not optimization_config.models:
|
||||||
|
logger.critical("No models or ensembles specified in optimization config. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Try to load the main forecasting config for the first model/ensemble to get the data path
|
||||||
|
first_model_config_path = Path(optimization_config.models[0].model_config_path)
|
||||||
|
main_forecasting_config_for_data = load_main_forecasting_config(str(first_model_config_path))
|
||||||
|
|
||||||
|
if main_forecasting_config_for_data is None:
|
||||||
|
logger.critical("Failed to load forecasting config for the first specified model/ensemble to get data path. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Use the DataConfig from the first loaded forecasting config
|
||||||
|
historical_data_config = DataConfig(
|
||||||
|
data_path=main_forecasting_config_for_data.data.data_path,
|
||||||
|
raw_datetime_col=main_forecasting_config_for_data.data.raw_datetime_col,
|
||||||
|
raw_datetime_format=main_forecasting_config_for_data.data.raw_datetime_format,
|
||||||
|
datetime_col=main_forecasting_config_for_data.data.datetime_col,
|
||||||
|
raw_target_col=main_forecasting_config_for_data.data.raw_target_col,
|
||||||
|
target_col=main_forecasting_config_for_data.data.target_col,
|
||||||
|
expected_frequency=main_forecasting_config_for_data.data.expected_frequency,
|
||||||
|
fill_initial_target_nans=main_forecasting_config_for_data.data.fill_initial_target_nans
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading original historical data from: {historical_data_config.data_path}")
|
||||||
|
try:
|
||||||
|
full_historical_df = load_raw_data(historical_data_config)
|
||||||
|
|
||||||
|
if full_historical_df.empty:
|
||||||
|
logger.critical("Loaded original historical data is empty. Cannot proceed. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Ensure data is at the expected frequency and sorted
|
||||||
|
full_historical_df = full_historical_df.sort_index().asfreq(historical_data_config.expected_frequency)
|
||||||
|
# Fill any NaNs introduced by asfreq if not already handled by fill_initial_target_nans
|
||||||
|
if full_historical_df[historical_data_config.target_col].isnull().any():
|
||||||
|
logger.warning(f"NaNs found after setting frequency {historical_data_config.expected_frequency}. Applying ffill().bfill().")
|
||||||
|
full_historical_df[historical_data_config.target_col] = full_historical_df[historical_data_config.target_col].ffill().bfill()
|
||||||
|
if full_historical_df[historical_data_config.target_col].isnull().any():
|
||||||
|
logger.critical("NaNs still remain after filling. Cannot proceed. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Original historical data loaded and prepared. Shape: {full_historical_df.shape}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(f"Failed to load or prepare original historical data from {historical_data_config.data_path}: {e}", exc_info=True)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Define Evaluation Window and Step ---
|
||||||
|
optimization_horizon_hours = optimization_config.optimization_horizon_hours
|
||||||
|
step_size_hours = 1 # Evaluate every hour by sliding the window by 1 hour
|
||||||
|
|
||||||
|
logger.info(f"Using optimization horizon: {optimization_horizon_hours} hours with a step size of {step_size_hours} hour(s).")
|
||||||
|
|
||||||
|
# --- Storage for results per time window ---
|
||||||
|
window_results_list = []
|
||||||
|
|
||||||
|
# --- Load Models/Ensembles and Instantiate Providers ---
|
||||||
|
# Store loaded provider instances, keyed by the name from optim_config
|
||||||
|
forecast_providers: Dict[str, ForecastProvider] = {} # Store provider instances
|
||||||
|
|
||||||
|
for model_eval_config in optimization_config.models:
|
||||||
|
provider_name = model_eval_config.name
|
||||||
|
artifact_type = model_eval_config.type
|
||||||
|
artifact_path = Path(model_eval_config.model_path) # Path to .ckpt or .json
|
||||||
|
config_path = Path(model_eval_config.model_config_path) # Path to YAML config
|
||||||
|
|
||||||
|
provider_instance: Optional[ForecastProvider] = None # Initialize provider instance
|
||||||
|
|
||||||
|
if artifact_type == 'model':
|
||||||
|
logger.info(f"Attempting to load single model artifact and create provider: {provider_name}")
|
||||||
|
target_scaler_path = Path(model_eval_config.target_scaler_path) if model_eval_config.target_scaler_path else None
|
||||||
|
input_size_path = artifact_path.parent / "input_size.pt" # Derive path convention
|
||||||
|
if not input_size_path.exists() and artifact_path.parent.name == 'checkpoints':
|
||||||
|
input_size_path = artifact_path.parent.parent / "input_size.pt"
|
||||||
|
|
||||||
|
loaded_artifact_info = load_single_model_artifact(
|
||||||
|
model_path=artifact_path,
|
||||||
|
config_path=config_path,
|
||||||
|
input_size_path=input_size_path,
|
||||||
|
target_scaler_path=target_scaler_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if loaded_artifact_info:
|
||||||
|
try:
|
||||||
|
provider_instance = SingleModelProvider(
|
||||||
|
model_instance=loaded_artifact_info['model_instance'],
|
||||||
|
feature_config=loaded_artifact_info['feature_config'],
|
||||||
|
target_col=loaded_artifact_info['main_forecasting_config'].data.target_col, # Get target col from loaded config
|
||||||
|
target_scaler=loaded_artifact_info['target_scaler']
|
||||||
|
)
|
||||||
|
# Validation check (basic horizon check)
|
||||||
|
if 1 not in provider_instance.feature_config.forecast_horizon:
|
||||||
|
logger.error(f"Model '{provider_name}' forecast horizon {provider_instance.feature_config.forecast_horizon} does not include 1 hour. Cannot use for this evaluation.")
|
||||||
|
provider_instance = None # Discard if validation fails
|
||||||
|
else:
|
||||||
|
logger.info(f"Successfully created SingleModelProvider for '{provider_name}'.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to instantiate SingleModelProvider for '{provider_name}': {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Single model artifact '{provider_name}' could not be loaded. Skipping provider creation.")
|
||||||
|
|
||||||
|
elif artifact_type == 'ensemble':
|
||||||
|
logger.info(f"Attempting to load ensemble artifact and create provider: {provider_name}")
|
||||||
|
hpo_base_output_dir_for_ensemble = artifact_path.parent
|
||||||
|
|
||||||
|
loaded_artifact_info = load_ensemble_artifact(
|
||||||
|
ensemble_definition_path=artifact_path,
|
||||||
|
hpo_base_output_dir=hpo_base_output_dir_for_ensemble
|
||||||
|
)
|
||||||
|
|
||||||
|
if loaded_artifact_info:
|
||||||
|
try:
|
||||||
|
# Ensure necessary keys are present before instantiation
|
||||||
|
required_keys = ['fold_artifacts', 'ensemble_method', 'ensemble_feature_config', 'ensemble_target_col']
|
||||||
|
if not all(key in loaded_artifact_info for key in required_keys):
|
||||||
|
missing_keys = [key for key in required_keys if key not in loaded_artifact_info]
|
||||||
|
raise ValueError(f"Ensemble artifact info is missing required keys: {missing_keys}")
|
||||||
|
|
||||||
|
provider_instance = EnsembleProvider(
|
||||||
|
fold_artifacts=loaded_artifact_info['fold_artifacts'],
|
||||||
|
ensemble_method=loaded_artifact_info['ensemble_method'],
|
||||||
|
ensemble_feature_config=loaded_artifact_info['ensemble_feature_config'],
|
||||||
|
ensemble_target_col=loaded_artifact_info['ensemble_target_col']
|
||||||
|
)
|
||||||
|
# Validation check (basic horizon check)
|
||||||
|
if 1 not in provider_instance.common_forecast_horizons:
|
||||||
|
logger.error(f"Ensemble '{provider_name}' common forecast horizon {provider_instance.common_forecast_horizons} does not include 1 hour. Cannot use for this evaluation.")
|
||||||
|
provider_instance = None # Discard if validation fails
|
||||||
|
else:
|
||||||
|
logger.info(f"Successfully created EnsembleProvider for '{provider_name}'.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to instantiate EnsembleProvider for '{provider_name}': {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Ensemble artifact '{provider_name}' could not be loaded. Skipping provider creation.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error(f"Unknown artifact type '{artifact_type}' for '{provider_name}'. Skipping.")
|
||||||
|
continue # Skip to next model_eval_config
|
||||||
|
|
||||||
|
# Store the successfully created provider instance
|
||||||
|
if provider_instance:
|
||||||
|
forecast_providers[provider_name] = provider_instance
|
||||||
|
|
||||||
|
# --- End Loading ---
|
||||||
|
|
||||||
|
|
||||||
|
if not forecast_providers:
|
||||||
|
logger.critical("No forecast providers were successfully created. Cannot proceed with evaluation. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# --- Calculate Max Lookback Needed Across All Providers ---
|
||||||
|
max_required_lookback = 0
|
||||||
|
for provider_name, provider in forecast_providers.items():
|
||||||
|
try:
|
||||||
|
lookback = provider.get_required_lookback()
|
||||||
|
max_required_lookback = max(max_required_lookback, lookback)
|
||||||
|
logger.debug(f"Provider '{provider_name}' requires lookback: {lookback}")
|
||||||
|
except AttributeError:
|
||||||
|
logger.error(f"Provider '{provider_name}' does not have a 'get_required_lookback' method. Cannot determine lookback requirements accurately. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lookback for provider '{provider_name}': {e}. Exiting.", exc_info=True)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Maximum lookback required across all providers: {max_required_lookback} hours.")
|
||||||
|
|
||||||
|
# The first timestamp for which we can generate a forecast needs `max_required_lookback` points *before* it.
|
||||||
|
# If optimization starts at `window_start_time` (iloc `i`), the forecast generation needs data up to `i-1`.
|
||||||
|
# The historical slice passed to `get_forecast` must contain `max_required_lookback` points, ending at `i-1`.
|
||||||
|
# Therefore, the slice starts at `i - max_required_lookback`. This must be >= 0.
|
||||||
|
# So, `i >= max_required_lookback`.
|
||||||
|
first_window_start_iloc = max_required_lookback
|
||||||
|
|
||||||
|
# The last window starts such that the window ends within the data: `i + optimization_horizon_hours - 1 < len(df)`
|
||||||
|
# So, `i < len(df) - optimization_horizon_hours + 1`.
|
||||||
|
last_window_start_iloc = len(full_historical_df) - optimization_horizon_hours
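
    # Worked example of the window arithmetic above (hypothetical numbers): with 8760 hours of
    # history, a maximum required lookback of 72 and a 24 h optimization horizon, the first
    # usable window starts at iloc 72 and the last at iloc 8760 - 24 = 8736, giving
    # 8736 - 72 + 1 = 8665 evaluation windows when stepping by 1 hour.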
|
||||||
|
|
||||||
|
if first_window_start_iloc > last_window_start_iloc:
|
||||||
|
logger.critical(f"Not enough historical data ({len(full_historical_df)} hours) for the required lookback ({max_required_lookback}) and optimization horizon ({optimization_horizon_hours}). First possible window start iloc: {first_window_start_iloc}, last possible: {last_window_start_iloc}. Exiting.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Evaluating over historical windows from iloc {first_window_start_iloc} to {last_window_start_iloc}.")
|
||||||
|
|
||||||
|
# --- Evaluation Loop ---
|
||||||
|
for i in range(first_window_start_iloc, last_window_start_iloc + 1, step_size_hours):
|
||||||
|
# Define the actual optimization window in terms of iloc and time
|
||||||
|
window_start_iloc = i
|
||||||
|
window_end_iloc = i + optimization_horizon_hours - 1 # Inclusive index for the window end
|
||||||
|
|
||||||
|
# Check if the window is complete within the dataset bounds
|
||||||
|
if window_end_iloc >= len(full_historical_df):
|
||||||
|
            logger.warning(f"Skipping window starting at iloc {window_start_iloc}: extends beyond available data (needs up to iloc {window_end_iloc}, max is {len(full_historical_df)-1}).")
            continue

        window_timestamps = full_historical_df.index[window_start_iloc : window_end_iloc + 1]

        # Double-check length just in case
        if len(window_timestamps) != optimization_horizon_hours:
            logger.warning(f"Skipping window starting at iloc {window_start_iloc} due to unexpected timestamp slice length ({len(window_timestamps)} instead of {optimization_horizon_hours} hours).")
            continue

        window_start_time = window_timestamps[0]
        window_end_time = window_timestamps[-1]
        logger.info(f"Processing window: {window_start_time.strftime('%Y-%m-%d %H:%M')} to {window_end_time.strftime('%Y-%m-%d %H:%M')} (iloc {window_start_iloc})")

        # --- Prepare Historical Slice for Forecasting ---
        # We need data *up to* the beginning of the optimization window, including lookback.
        # Slice should end at iloc `window_start_iloc - 1`.
        # Slice should start at `window_start_iloc - max_required_lookback`.
        hist_slice_start_iloc = max(0, window_start_iloc - max_required_lookback)
        hist_slice_end_iloc = window_start_iloc  # Exclusive end iloc for slicing, so it includes up to window_start_iloc - 1

        if hist_slice_end_iloc <= hist_slice_start_iloc:
            logger.error(f"Invalid historical slice range for window starting at {window_start_time}: start_iloc={hist_slice_start_iloc}, end_iloc={hist_slice_end_iloc}. Skipping window.")
            continue

        historical_slice_for_forecasting = full_historical_df.iloc[hist_slice_start_iloc : hist_slice_end_iloc].copy()

        # Check if the slice has the expected length (at least max_required_lookback, unless near the start of the data)
        if len(historical_slice_for_forecasting) < max_required_lookback and window_start_iloc >= max_required_lookback:
            logger.warning(f"Historical slice for window starting {window_start_time} has unexpected length {len(historical_slice_for_forecasting)}, expected {max_required_lookback}. Check slicing logic. Skipping.")
            continue
        elif len(historical_slice_for_forecasting) == 0:
            logger.warning(f"Historical slice for window starting {window_start_time} is empty. Skipping.")
            continue

        logger.debug(f"Using historical slice from {historical_slice_for_forecasting.index.min()} to {historical_slice_for_forecasting.index.max()} (Length: {len(historical_slice_for_forecasting)}) for forecasting.")

        # --- Collect Window Results ---
        window_results = {
            'start_time': window_start_time,
            'end_time': window_end_time,
            'actual_prices': full_historical_df[historical_data_config.target_col].iloc[window_start_iloc : window_end_iloc + 1].tolist()
        }

        # --- Baseline Optimization ---
        baseline_prices_input = np.array(window_results['actual_prices'])
        logger.debug(f"Running baseline optimization for window starting {window_start_time}")
        try:
            baseline_status, baseline_profit, baseline_power, baseline_B = solve_battery_optimization_hourly(
                baseline_prices_input,
                optimization_config.initial_b,
                optimization_config.max_capacity,
                optimization_config.max_rate
            )
            window_results['baseline'] = {
                "status": baseline_status,
                "profit": baseline_profit,
                "power_schedule": baseline_power.tolist() if baseline_power is not None else None,
                "B_schedule": baseline_B.tolist() if baseline_B is not None else None
            }
            logger.debug(f"Baseline profit: {baseline_profit if baseline_profit is not None else 'N/A'}")
        except Exception as e:
            logger.error(f"Baseline optimization failed for window starting {window_start_time}: {e}", exc_info=True)
            window_results['baseline'] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None}

        # --- Forecast Provider Optimizations ---
        for provider_name, provider_instance in forecast_providers.items():
            logger.debug(f"Generating forecast and running optimization for provider '{provider_name}' for window starting {window_start_time}")

            # Generate forecast using the provider's get_forecast method
            try:
                forecast_prices_input = provider_instance.get_forecast(
                    historical_data_slice=historical_slice_for_forecasting.copy(),  # Pass a copy
                    optimization_horizon_hours=optimization_horizon_hours
                )
            except Exception as e:
                logger.error(f"Error calling get_forecast for provider '{provider_name}': {e}", exc_info=True)
                forecast_prices_input = None

            if forecast_prices_input is None or len(forecast_prices_input) != optimization_horizon_hours:
                logger.warning(f"Forecast generation failed or returned incorrect length ({len(forecast_prices_input) if forecast_prices_input is not None else 0} instead of {optimization_horizon_hours}) for provider '{provider_name}' window starting {window_start_time}. Skipping optimization.")
                window_results[provider_name] = {"status": "Forecast Generation Failed", "profit": None, "power_schedule": None, "B_schedule": None}
                continue  # Skip optimization for this provider/window

            # Ensure the forecast input is a numpy array of the correct shape
            if not isinstance(forecast_prices_input, np.ndarray) or forecast_prices_input.shape != (optimization_horizon_hours,):
                logger.error(f"Forecast input for provider '{provider_name}' has incorrect format ({type(forecast_prices_input)}, shape {forecast_prices_input.shape if isinstance(forecast_prices_input, np.ndarray) else 'N/A'}). Expected ({optimization_horizon_hours},). Skipping optimization.")
                window_results[provider_name] = {"status": "Invalid Forecast Format", "profit": None, "power_schedule": None, "B_schedule": None}
                continue

            # --- Run Optimization with Forecast Prices ---
            try:
                model_status, model_profit, model_power, model_B = solve_battery_optimization_hourly(
                    forecast_prices_input,
                    optimization_config.initial_b,
                    optimization_config.max_capacity,
                    optimization_config.max_rate
                )

                window_results[provider_name] = {
                    "status": model_status,
                    "profit": model_profit,
                    "power_schedule": model_power.tolist() if model_power is not None else None,
                    "B_schedule": model_B.tolist() if model_B is not None else None
                }
                logger.debug(f"Provider '{provider_name}' profit: {model_profit if model_profit is not None else 'N/A'}")

            except Exception as e:
                logger.error(f"Optimization failed for provider '{provider_name}' window starting {window_start_time}: {e}", exc_info=True)
                window_results[provider_name] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None}

        # Append results for this window
        window_results_list.append(window_results)
        logger.debug(f"Finished processing window starting at: {window_start_time.strftime('%Y-%m-%d %H:%M')}")

    logger.info("Finished processing all evaluation windows.")

    # --- Post-processing and Plotting ---
    logger.info("Starting results analysis and plotting.")

    if not window_results_list:
        logger.warning("No window results were collected. Skipping plotting.")
        exit(0)  # Not necessarily an error state

    # Convert results list to a DataFrame
    flat_results = []
    successfully_loaded_provider_names = list(forecast_providers.keys())  # Names of providers used

    for window_res in window_results_list:
        base_info = {
            'start_time': window_res['start_time'],
            'end_time': window_res['end_time'],
        }
        # Add baseline results
        flat_results.append({**base_info, 'type': 'baseline', **window_res.get('baseline', {})})
        # Add provider results
        for provider_name in successfully_loaded_provider_names:
            provider_res = window_res.get(provider_name, {})  # Get results or empty dict
            flat_results.append({**base_info, 'type': provider_name, **provider_res})

    results_df = pd.DataFrame(flat_results)
    results_df['start_time'] = pd.to_datetime(results_df['start_time'])  # Ensure datetime type

    # Filter out rows where essential optimization results are missing
    # results_df.dropna(subset=['profit', 'power_schedule'], inplace=True)  # Be careful with dropna

    # Calculate Profit Absolute Error over time
    profit_pivot = results_df.pivot_table(index='start_time', columns='type', values='profit')

    mae_df = pd.DataFrame(index=profit_pivot.index)
    if 'baseline' in profit_pivot.columns:
        for provider_name in successfully_loaded_provider_names:
            if provider_name in profit_pivot.columns:
                # Use .sub() and .abs() to handle potential NaNs gracefully
                mae_df[f'Profit_Abs_Error_{provider_name}'] = profit_pivot[provider_name].sub(profit_pivot['baseline']).abs()
            else:
                logger.warning(f"Cannot calculate profit MAE for provider '{provider_name}'. Data not found in pivoted results.")
    else:
        logger.warning("Cannot calculate profit MAE because baseline results are missing or incomplete.")

    # --- Plotting ---

    # Plot 1: Price and First Hour's Power Schedule Over Time
    logger.info("Generating Price and Power Schedule plot.")

    continuous_power_data = []
    for window_res in window_results_list:
        start_time = window_res['start_time']
        # Baseline power
        baseline_data = window_res.get('baseline', {})
        if baseline_data.get('power_schedule') and len(baseline_data['power_schedule']) > 0:
            continuous_power_data.append({'time': start_time, 'type': 'baseline', 'power': baseline_data['power_schedule'][0]})
        # Provider powers
        for provider_name in successfully_loaded_provider_names:
            provider_data = window_res.get(provider_name, {})
            if provider_data.get('power_schedule') and len(provider_data['power_schedule']) > 0:
                continuous_power_data.append({'time': start_time, 'type': provider_name, 'power': provider_data['power_schedule'][0]})

    continuous_power_df = pd.DataFrame(continuous_power_data)
    if not continuous_power_df.empty:
        continuous_power_df['time'] = pd.to_datetime(continuous_power_df['time'])

        # Get historical prices corresponding to the evaluation window start times
        eval_start_times = results_df['start_time'].unique()
        price_plot_df = full_historical_df.loc[eval_start_times, [historical_data_config.target_col]].reset_index()
        price_plot_df.rename(columns={price_plot_df.columns[0]: 'time', historical_data_config.target_col: 'price'}, inplace=True)  # Use positional index for timestamp column rename

        plot_range_start = continuous_power_df['time'].min()
        plot_range_end = continuous_power_df['time'].max()

        # Filter data for the plot range
        filtered_price_df = price_plot_df[(price_plot_df['time'] >= plot_range_start) & (price_plot_df['time'] <= plot_range_end)]
        filtered_power_df = continuous_power_df[(continuous_power_df['time'] >= plot_range_start) & (continuous_power_df['time'] <= plot_range_end)]

        if not filtered_power_df.empty:
            fig1, ax1 = plt.subplots(figsize=(15, 7))
            ax2 = ax1.twinx()

            sns.lineplot(data=filtered_price_df, x='time', y='price', ax=ax1, color='gray', linestyle='--', label='Historical Price (Window Start)', zorder=1)
            ax1.set_ylabel('Price (€/MWh)', color='gray')
            ax1.tick_params(axis='y', labelcolor='gray')

            sns.lineplot(data=filtered_power_df, x='time', y='power', hue='type', ax=ax2, zorder=2)
            ax2.set_ylabel('Power (MW)')

            h1, l1 = ax1.get_legend_handles_labels()
            h2, l2 = ax2.get_legend_handles_labels()
            ax2.legend(h1 + h2, l1 + l2, loc='upper left', title='Schedule Type')
            ax1.get_legend().remove()  # Remove the original legend from ax1

            ax1.set_xlabel('Time')
            ax1.set_title('Battery Power Schedule (1st Hour) vs. Historical Price (Window Start)')
            plt.tight_layout()
            plt.savefig("power_schedule_vs_price.png")
            logger.info("Price and Power Schedule plot saved as power_schedule_vs_price.png")
            # plt.show()
        else:
            logger.warning("No power data available within the determined plot range.")

    else:
        logger.warning("No continuous power data generated for plotting power schedule.")

    # Plot 2: Absolute Profit Error over time
    logger.info("Generating Profit Absolute Error plot.")
    if not mae_df.empty and not mae_df.isnull().all().all():  # Check if not empty and not all NaN
        fig2, ax = plt.subplots(figsize=(15, 7))

        # Use the plot range from the power plot if available
        mae_plot_range_start = plot_range_start if 'plot_range_start' in locals() else mae_df.index.min()
        mae_plot_range_end = plot_range_end if 'plot_range_end' in locals() else mae_df.index.max()

        filtered_mae_df = mae_df[(mae_df.index >= mae_plot_range_start) & (mae_df.index <= mae_plot_range_end)].copy()  # Create copy
        # Optional: Handle or remove columns that are all NaN within the range
        filtered_mae_df.dropna(axis=1, how='all', inplace=True)

        if not filtered_mae_df.empty:
            sns.lineplot(data=filtered_mae_df, ax=ax)
            ax.set_xlabel('Time')
            ax.set_ylabel('Absolute Profit Error vs. Baseline (€)')
            ax.set_title('Absolute Profit Error of Providers vs. Baseline over Time')
            ax.legend(title='Provider Type')
            plt.tight_layout()
            plt.savefig("profit_abs_error_over_time.png")
            logger.info("Profit Absolute Error plot saved as profit_abs_error_over_time.png")
            # plt.show()
        else:
            logger.warning("MAE data is all NaN or empty within the plot range. Skipping MAE plot.")
    else:
        logger.warning("No valid data available to plot Profit Absolute Error.")

    logger.info("Evaluation and plotting completed.")
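To make the index arithmetic of the evaluation loop above concrete, here is a minimal, self-contained sketch; the variable names mirror the script, but the concrete numbers and the toy DataFrame are illustrative only:

import pandas as pd

# Illustrative only: hourly index, 24 h optimization windows, 96-row lookback.
full_historical_df = pd.DataFrame(
    {"Price": range(500)},
    index=pd.date_range("2020-01-01", periods=500, freq="h"),
)
optimization_horizon_hours = 24
max_required_lookback = 96

window_start_iloc = 200
window_end_iloc = window_start_iloc + optimization_horizon_hours - 1  # inclusive end

# Prices the optimizer is evaluated against (exactly 24 rows).
window_prices = full_historical_df["Price"].iloc[window_start_iloc : window_end_iloc + 1]

# History handed to the forecast providers: ends right before the window starts
# and reaches back max_required_lookback rows (clamped at the start of the data).
hist_slice_start_iloc = max(0, window_start_iloc - max_required_lookback)
historical_slice = full_historical_df.iloc[hist_slice_start_iloc : window_start_iloc]

assert len(window_prices) == optimization_horizon_hours
assert len(historical_slice) == max_required_lookback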
0    optimizer/evaluation/__init__.py       Normal file
0    optimizer/evaluation/metrics.py        Normal file
0    optimizer/evaluation/vizualization.py  Normal file
0    optimizer/forecasting/__init__.py      Normal file
22   optimizer/forecasting/base.py          Normal file
@ -0,0 +1,22 @@
# forecasting/base.py
from typing import List, Optional

import numpy as np
import pandas as pd


class ForecastProvider:
    """Interface expected by the optimizer's evaluation loop."""

    # Signature aligned with SingleModelProvider / EnsembleProvider and with the
    # evaluation script, which calls
    # get_forecast(historical_data_slice=..., optimization_horizon_hours=...).
    def get_forecast(self,
                     historical_data_slice: pd.DataFrame,
                     optimization_horizon_hours: int) -> Optional[np.ndarray]:
        """Returns an hourly forecast covering the optimization horizon."""
        raise NotImplementedError

    def get_required_lookback(self) -> int:
        """Returns the minimum number of historical data points required."""
        raise NotImplementedError

    def get_forecast_horizons(self) -> List[int]:
        """Returns the list of native forecast horizons."""
        raise NotImplementedError
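For illustration only, a minimal provider satisfying this interface could be a naive persistence forecaster. This class is not part of the commit; its name and behaviour are assumptions:

# hypothetical example - a naive baseline implementing ForecastProvider
import numpy as np
import pandas as pd

from optimizer.forecasting.base import ForecastProvider


class PersistenceProvider(ForecastProvider):
    """Repeats the last observed target value over the whole optimization horizon."""

    def __init__(self, target_col: str):
        self.target_col = target_col

    def get_required_lookback(self) -> int:
        return 1  # only the most recent observation is needed

    def get_forecast(self,
                     historical_data_slice: pd.DataFrame,
                     optimization_horizon_hours: int) -> np.ndarray | None:
        if historical_data_slice.empty:
            return None
        last_value = float(historical_data_slice[self.target_col].iloc[-1])
        return np.full(optimization_horizon_hours, last_value)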
188  optimizer/forecasting/ensemble.py  Normal file
@ -0,0 +1,188 @@
import logging
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast

logger = logging.getLogger(__name__)


class EnsembleProvider(ForecastProvider):
    """Provides forecasts using an ensemble of trained LSTM models."""

    def __init__(
        self,
        fold_artifacts: List[Dict[str, Any]],
        ensemble_method: str,
        ensemble_feature_config: FeatureConfig,  # Assumed consistent across folds by loading logic
        ensemble_target_col: str,  # Assumed consistent
    ):
        if not fold_artifacts:
            raise ValueError("EnsembleProvider requires at least one fold artifact.")

        self.fold_artifacts = fold_artifacts
        self.ensemble_method = ensemble_method
        # Store common config for reference, but use fold-specific details in get_forecast
        self.ensemble_feature_config = ensemble_feature_config
        self.ensemble_target_col = ensemble_target_col
        self.common_forecast_horizons = sorted(ensemble_feature_config.forecast_horizon)  # Assumed consistent

        # Calculate max lookback needed across all folds
        max_lookback = 0
        for i, fold in enumerate(fold_artifacts):
            try:
                fold_feature_config = fold['feature_config']
                fold_seq_len = fold_feature_config.sequence_length

                feature_lookback = 0
                if fold_feature_config.lags:
                    feature_lookback = max(feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    feature_lookback = max(feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))

                fold_total_lookback = fold_seq_len + feature_lookback
                max_lookback = max(max_lookback, fold_total_lookback)
            except KeyError as e:
                raise ValueError(f"Fold artifact {i} is missing expected key: {e}") from e
            except Exception as e:
                raise ValueError(f"Error processing fold artifact {i} for lookback calculation: {e}") from e

        self._required_lookback = max_lookback
        logger.debug(f"EnsembleProvider initialized with {len(fold_artifacts)} folds. Method: '{ensemble_method}'. Required lookback: {self._required_lookback}")

        if ensemble_method not in ['mean', 'median']:
            raise ValueError(f"Unsupported ensemble method: {ensemble_method}. Use 'mean' or 'median'.")

    def get_required_lookback(self) -> int:
        return self._required_lookback

    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int
    ) -> np.ndarray | None:
        """Generates forecasts from each fold model, interpolates, and aggregates."""
        logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours using {self.ensemble_method}.")
        if len(historical_data_slice) < self._required_lookback:
            logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
            return None

        fold_forecasts_interpolated = []
        last_actual_price = historical_data_slice[self.ensemble_target_col].iloc[-1]  # Common anchor for all folds

        for i, fold_artifact in enumerate(self.fold_artifacts):
            fold_id = fold_artifact.get("fold_id", i + 1)
            try:
                fold_model: LSTMForecastLightningModule = fold_artifact['model_instance']
                fold_feature_config: FeatureConfig = fold_artifact['feature_config']
                fold_target_scaler: Optional[Any] = fold_artifact['target_scaler']
                fold_target_col: str = fold_artifact['main_forecasting_config'].data.target_col  # Use fold-specific target
                fold_seq_len = fold_feature_config.sequence_length
                fold_horizons = sorted(fold_feature_config.forecast_horizon)

                # Calculate lookback needed *for this specific fold* to check slice length
                fold_feature_lookback = 0
                if fold_feature_config.lags:
                    fold_feature_lookback = max(fold_feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    fold_feature_lookback = max(fold_feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
                fold_total_lookback = fold_seq_len + fold_feature_lookback

                if len(historical_data_slice) < fold_total_lookback:
                    logger.warning(f"Fold {fold_id}: Skipping fold. Insufficient historical data in slice for this fold's lookback ({fold_total_lookback} needed).")
                    continue

                # 1. Feature Engineering (using fold's config)
                # Slice needs to be long enough for this fold's total lookback.
                # The input slice `historical_data_slice` should already be long enough based on max_lookback.
                engineered_df_fold = engineer_features(historical_data_slice.copy(), fold_target_col, fold_feature_config)

                if engineered_df_fold.isnull().any().any():
                    logger.warning(f"Fold {fold_id}: NaNs found after feature engineering. Attempting fill.")
                    engineered_df_fold = engineered_df_fold.ffill().bfill()
                    if engineered_df_fold.isnull().any().any():
                        logger.error(f"Fold {fold_id}: NaNs persist after fill. Skipping fold.")
                        continue

                # 2. Create *one* input sequence (using fold's sequence length)
                if len(engineered_df_fold) < fold_seq_len:
                    logger.error(f"Fold {fold_id}: Engineered data ({len(engineered_df_fold)}) is shorter than fold sequence length ({fold_seq_len}). Skipping fold.")
                    continue

                input_sequence_data_fold = engineered_df_fold.iloc[-fold_seq_len:].copy()
                feature_columns_fold = [col for col in engineered_df_fold.columns if col != fold_target_col]  # Example
                if not feature_columns_fold:
                    feature_columns_fold = engineered_df_fold.columns.tolist()
                input_sequence_np_fold = input_sequence_data_fold[feature_columns_fold].values

                if input_sequence_np_fold.shape != (fold_seq_len, len(feature_columns_fold)):
                    logger.error(f"Fold {fold_id}: Input sequence has wrong shape. Expected ({fold_seq_len}, {len(feature_columns_fold)}), got {input_sequence_np_fold.shape}. Skipping fold.")
                    continue

                input_tensor_fold = torch.FloatTensor(input_sequence_np_fold).unsqueeze(0)

                # 3. Run Inference (using fold's model)
                fold_model.eval()
                with torch.no_grad():
                    predictions_scaled_fold = fold_model(input_tensor_fold)  # Shape (1, num_fold_horizons)

                if predictions_scaled_fold.ndim != 2 or predictions_scaled_fold.shape[0] != 1 or predictions_scaled_fold.shape[1] != len(fold_horizons):
                    logger.error(f"Fold {fold_id}: Prediction output shape mismatch. Expected (1, {len(fold_horizons)}), got {predictions_scaled_fold.shape}. Skipping fold.")
                    continue

                predictions_scaled_np_fold = predictions_scaled_fold.squeeze(0).cpu().numpy()

                # 4. Inverse Transform (using fold's scaler)
                predictions_original_scale_fold = predictions_scaled_np_fold
                if fold_target_scaler:
                    try:
                        predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
                    except Exception as e:
                        logger.error(f"Fold {fold_id}: Failed to apply inverse transform: {e}. Skipping fold.", exc_info=True)
                        continue

                # 5. Interpolate (using fold's horizons)
                interpolated_forecast_fold = interpolate_forecast(
                    native_horizons=fold_horizons,
                    native_predictions=predictions_original_scale_fold,
                    target_horizon=optimization_horizon_hours,
                    last_known_actual=last_actual_price
                )

                if interpolated_forecast_fold is not None:
                    fold_forecasts_interpolated.append(interpolated_forecast_fold)
                    logger.debug(f"Fold {fold_id}: Successfully generated interpolated forecast.")
                else:
                    logger.warning(f"Fold {fold_id}: Interpolation failed. Skipping fold.")

            except Exception as e:
                logger.error(f"Error processing ensemble fold {fold_id}: {e}", exc_info=True)
                continue  # Skip this fold on error

        # --- Aggregation ---
        if not fold_forecasts_interpolated:
            logger.error("No successful forecasts generated from any ensemble folds.")
            return None

        logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} folds using '{self.ensemble_method}'.")
        stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0)  # Shape (n_folds, target_horizon)

        if self.ensemble_method == 'mean':
            final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
        elif self.ensemble_method == 'median':
            final_ensemble_forecast = np.median(stacked_predictions, axis=0)
        else:
            # Should be caught in __init__, but double-check
            logger.error(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
            return None

        logger.debug("EnsembleProvider: Successfully generated forecast.")
        return final_ensemble_forecast
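The aggregation step above is plain NumPy; a toy example (numbers invented) shows why `median` is the more outlier-robust choice when one fold misbehaves:

import numpy as np

# Three hypothetical per-fold forecasts for a 4-hour horizon; the third fold is an outlier.
fold_forecasts = [
    np.array([50.0, 52.0, 55.0, 53.0]),
    np.array([49.0, 51.0, 56.0, 54.0]),
    np.array([90.0, 95.0, 99.0, 97.0]),
]
stacked = np.stack(fold_forecasts, axis=0)   # shape (n_folds, horizon)
print(np.mean(stacked, axis=0))              # pulled up by the outlier fold
print(np.median(stacked, axis=0))            # stays close to the two agreeing folds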
150  optimizer/forecasting/single_model.py  Normal file
@ -0,0 +1,150 @@
import logging
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Imports from our project structure
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast

logger = logging.getLogger(__name__)


class SingleModelProvider(ForecastProvider):
    """Provides forecasts using a single trained LSTM model."""

    def __init__(
        self,
        model_instance: LSTMForecastLightningModule,
        feature_config: FeatureConfig,
        target_col: str,
        target_scaler: Optional[Any],  # BaseEstimator, TransformerMixin -> more specific if possible
        # input_size: int  # Not needed directly if model instance is configured
    ):
        self.model = model_instance
        self.feature_config = feature_config
        self.target_col = target_col
        self.target_scaler = target_scaler
        self.sequence_length = feature_config.sequence_length
        self.forecast_horizons = sorted(feature_config.forecast_horizon)  # Ensure sorted

        # Calculate required lookback for feature engineering
        feature_lookback = 0
        if feature_config.lags:
            feature_lookback = max(feature_lookback, max(feature_config.lags))
        if feature_config.rolling_window_sizes:
            # Rolling window of size W needs W-1 previous points
            feature_lookback = max(feature_lookback, max(w - 1 for w in feature_config.rolling_window_sizes))

        # Total lookback: sequence length for model input + feature engineering needs.
        # We need `sequence_length` points for the *last* input sequence, and the first
        # point of that sequence needs `feature_lookback` points before it. Since the
        # input sequence ends *before* the first forecast point (t=1), we need
        # `sequence_length + feature_lookback` points before t=1.
        self._required_lookback = self.sequence_length + feature_lookback
        logger.debug(f"SingleModelProvider initialized. Required lookback: {self._required_lookback} (SeqLen: {self.sequence_length}, FeatLookback: {feature_lookback})")

    def get_required_lookback(self) -> int:
        return self._required_lookback

    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int
    ) -> np.ndarray | None:
        """Generates a forecast using the single model and interpolates it to hourly resolution."""
        logger.debug(f"SingleModelProvider: Generating forecast for {optimization_horizon_hours} hours.")
        if len(historical_data_slice) < self._required_lookback:
            logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
            return None

        try:
            # 1. Feature Engineering
            # Use the provided slice, which already includes the lookback.
            engineered_df = engineer_features(historical_data_slice.copy(), self.target_col, self.feature_config)

            # Check for NaNs after feature engineering before creating sequences
            if engineered_df.isnull().any().any():
                logger.warning("NaNs found after feature engineering. Attempting to fill with ffill/bfill.")
                # Be careful about filling target vs features if needed
                engineered_df = engineered_df.ffill().bfill()
                if engineered_df.isnull().any().any():
                    logger.error("NaNs persist after fill. Cannot create sequences.")
                    return None

            # 2. Create *one* input sequence ending at the last point of the historical slice.
            # This sequence is used to predict starting from the next hour (t=1).
            if len(engineered_df) < self.sequence_length:
                logger.error(f"Engineered data ({len(engineered_df)}) is shorter than sequence length ({self.sequence_length}).")
                return None

            input_sequence_data = engineered_df.iloc[-self.sequence_length:].copy()

            # Convert sequence data to a numpy array (excluding target if the model expects that).
            # Assuming the model takes all engineered features as input.
            # TODO: Verify the exact features the model expects (target included/excluded?)
            # Assuming all columns except maybe the original target are features.
            feature_columns = [col for col in engineered_df.columns if col != self.target_col]  # Example
            if not feature_columns:
                feature_columns = engineered_df.columns.tolist()  # Use all if target wasn't dropped
            input_sequence_np = input_sequence_data[feature_columns].values

            if input_sequence_np.shape != (self.sequence_length, len(feature_columns)):
                logger.error(f"Input sequence has wrong shape. Expected ({self.sequence_length}, {len(feature_columns)}), got {input_sequence_np.shape}")
                return None

            input_tensor = torch.FloatTensor(input_sequence_np).unsqueeze(0)  # Add batch dim

            # 3. Run Inference
            self.model.eval()
            with torch.no_grad():
                # Model output shape: (1, num_horizons)
                predictions_scaled = self.model(input_tensor)

            if predictions_scaled.ndim != 2 or predictions_scaled.shape[0] != 1 or predictions_scaled.shape[1] != len(self.forecast_horizons):
                logger.error(f"Model prediction output shape mismatch. Expected (1, {len(self.forecast_horizons)}), got {predictions_scaled.shape}.")
                return None

            predictions_scaled_np = predictions_scaled.squeeze(0).cpu().numpy()  # Shape: (num_horizons,)

            # 4. Inverse Transform
            predictions_original_scale = predictions_scaled_np
            if self.target_scaler:
                try:
                    # Scaler expects shape (n_samples, n_features), even if n_features=1
                    predictions_original_scale = self.target_scaler.inverse_transform(predictions_scaled_np.reshape(-1, 1)).flatten()
                    logger.debug("Applied inverse transform to predictions.")
                except Exception as e:
                    logger.error(f"Failed to apply inverse transform: {e}", exc_info=True)
                    # Decide whether to return scaled or None. Returning None is safer.
                    return None

            # 5. Interpolate
            # Use the last actual price from the input data as the anchor point t=0
            last_actual_price = historical_data_slice[self.target_col].iloc[-1]
            interpolated_forecast = interpolate_forecast(
                native_horizons=self.forecast_horizons,
                native_predictions=predictions_original_scale,
                target_horizon=optimization_horizon_hours,
                last_known_actual=last_actual_price
            )

            if interpolated_forecast is None:
                logger.error("Interpolation step failed.")
                return None

            logger.debug("SingleModelProvider: Successfully generated forecast.")
            return interpolated_forecast

        except Exception as e:
            logger.error(f"Error during single model forecast generation: {e}", exc_info=True)
            return None
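The `_required_lookback` arithmetic in `__init__` above is easiest to see with concrete numbers (the configuration values here are made up):

# Hypothetical feature configuration: 72-step input sequences, lag features up to
# 168 h back, and rolling windows of 24 h and 48 h.
sequence_length = 72
lags = [24, 48, 168]
rolling_window_sizes = [24, 48]

feature_lookback = max(max(lags), max(w - 1 for w in rolling_window_sizes))  # 168
required_lookback = sequence_length + feature_lookback                       # 240

# The provider therefore refuses any historical_data_slice shorter than 240 rows.
print(required_lookback)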
67   optimizer/forecasting/utils.py  Normal file
@ -0,0 +1,67 @@
from typing import List, Optional

import logging

import numpy as np

logger = logging.getLogger(__name__)


# --- Interpolation Helper ---
def interpolate_forecast(
    native_horizons: List[int],
    native_predictions: np.ndarray,
    target_horizon: int,
    last_known_actual: Optional[float] = None  # Optional: use last known price as the t=0 anchor
) -> np.ndarray | None:
    """
    Linearly interpolates model predictions at native horizons to a full hourly sequence.

    Args:
        native_horizons: List of horizons the model predicts (e.g., [1, 6, 12, 24]). Must not be empty.
        native_predictions: Numpy array of predictions corresponding to native_horizons. Must not be empty.
        target_horizon: The desired length of the hourly forecast (e.g., 24).
        last_known_actual: Optional last actual price before the forecast starts (at t=0). Used as anchor if 0 is not in native_horizons.

    Returns:
        A numpy array of shape (target_horizon,) with interpolated values, or None on error.
    """
    if not native_horizons or native_predictions is None or native_predictions.size == 0:
        logger.error("Cannot interpolate with empty native horizons or predictions.")
        return None
    if len(native_horizons) != len(native_predictions):
        logger.error(f"Mismatched lengths: native_horizons ({len(native_horizons)}) vs native_predictions ({len(native_predictions)})")
        return None

    try:
        # Ensure horizons are sorted
        sorted_indices = np.argsort(native_horizons)
        # Use float for potentially non-integer horizons if ever needed; points are usually > 0
        xp = np.array(native_horizons, dtype=float)[sorted_indices]
        fp = native_predictions[sorted_indices]

        # Target points for interpolation (hours 1 to target_horizon)
        x_target = np.arange(1, target_horizon + 1, dtype=float)

        # Add a t=0 point if provided and 0 is not already a native horizon.
        # This anchors the start of the interpolation.
        if last_known_actual is not None and xp[0] > 0:
            xp = np.insert(xp, 0, 0.0)
            fp = np.insert(fp, 0, last_known_actual)
        elif xp[0] == 0 and last_known_actual is not None:
            logger.debug("Native horizons include 0, using model's prediction for t=0 instead of last_known_actual.")
        elif last_known_actual is None and xp[0] > 0:
            logger.warning("No last_known_actual provided and native horizons start > 0. Interpolation might be less accurate at the beginning.")
            # If the first native horizon is > 1, np.interp will extrapolate constantly backwards from the first point.

        # Check if the target range requires extrapolation beyond the model's capability
        if target_horizon > xp[-1]:
            logger.warning(f"Target horizon ({target_horizon}) extends beyond the maximum native forecast horizon ({xp[-1]}). Extrapolation will occur (constant value).")

        interpolated_values = np.interp(x_target, xp, fp)
        return interpolated_values

    except Exception as e:
        logger.error(f"Linear interpolation failed: {e}", exc_info=True)
        return None
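A quick usage sketch of the helper above (the prediction values are invented):

import numpy as np

from optimizer.forecasting.utils import interpolate_forecast

# The model natively predicts at t = 1, 6, 12, 24; the optimizer wants all 24 hours.
hourly = interpolate_forecast(
    native_horizons=[1, 6, 12, 24],
    native_predictions=np.array([42.0, 47.5, 55.0, 40.0]),
    target_horizon=24,
    last_known_actual=41.0,  # anchors t=0 so hour 1 is not a flat extrapolation
)
print(hourly.shape)  # (24,)
print(hourly[:3])    # [42.  43.1 44.2] - linear ramp between the native points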
0    optimizer/optimization/__init__.py  Normal file
74   optimizer/optimization/battery.py   Normal file
@ -0,0 +1,74 @@
import cvxpy as cp
import numpy as np


def solve_battery_optimization_hourly(
    hourly_prices,      # Array of prices for each hour [0, 1, ..., n-1]
    initial_B,          # Current state of charge (MWh)
    max_capacity=1.0,   # MWh
    max_rate=1.0        # MW (+ve discharge / -ve charge)
):
    """
    Solves the battery scheduling optimization problem assuming hourly steps.
    At the start of each hour t=0..n-1 we decide how much power to sell/buy (P_t)
    and therefore the state of charge at the start of the next hour (B_{t+1}).

    Args:
        hourly_prices: Prices (€/MWh) for each hour t=0..n-1.
        initial_B: The state of charge at the beginning (time t=0).
        max_capacity: Maximum battery energy capacity (MWh).
        max_rate: Maximum charge/discharge power rate (MW).

    Returns:
        Tuple: (status, optimal_profit, power_schedule, B_schedule)
        Returns (status, None, None, None) if optimization fails.
    """
    n_hours = len(hourly_prices)

    # --- CVXPY Variables ---
    # Power flow for each hour t=0..n-1 (+ve discharge / -ve charge, matching the docstring)
    P = cp.Variable(n_hours, name="Power_Flow_MW")
    # State of charge at the START of each hour t=0..n (B[t] is B at hour t)
    B = cp.Variable(n_hours + 1, name="State_of_Charge_MWh")

    # --- Objective Function ---
    # Profit = sum(price[t] * Power[t])
    prices = np.array(hourly_prices)
    profit = prices @ P  # Equivalent to cp.sum(cp.multiply(prices, P)) / prices.dot(P)
    objective = cp.Maximize(profit)

    # --- Constraints ---
    constraints = []

    # 1. Initial B
    constraints.append(B[0] == initial_B)

    # 2. B Dynamics: B[t+1] = B[t] - P[t] * 1 hour
    #    (discharging, P[t] > 0, empties the battery; charging, P[t] < 0, fills it)
    constraints.append(B[1:] == B[:-1] - P)

    # 3. Power Rate Limits: -max_rate <= P[t] <= max_rate
    constraints.append(cp.abs(P) <= max_rate)

    # 4. B Limits: 0 <= B[t] <= max_capacity (applies to B[0]...B[n])
    constraints.append(B >= 0)
    constraints.append(B <= max_capacity)

    # --- Problem Definition and Solving ---
    problem = cp.Problem(objective, constraints)
    try:
        # Alternative solvers are ECOS, MOSEK, and SCS
        optimal_profit = problem.solve(solver=cp.CLARABEL, verbose=False)

        if problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
            return (
                problem.status,
                optimal_profit,
                P.value,  # NumPy array of optimal power flows per hour
                B.value   # NumPy array of optimal B at the start of each hour
            )
        else:
            print(f"Optimization failed. Solver status: {problem.status}")
            return problem.status, None, None, None

    except cp.error.SolverError as e:
        print(f"Solver Error: {e}")
        return "Solver Error", None, None, None
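A small usage sketch of the solver above (prices are invented; assumes the package layout shown in this commit and a cvxpy installation with the CLARABEL solver):

import numpy as np

from optimizer.optimization.battery import solve_battery_optimization_hourly

# Cheap night, expensive evening: the optimal schedule charges early and discharges late.
prices = np.array([20.0, 15.0, 18.0, 60.0, 80.0, 70.0])

status, profit, power_schedule, soc_schedule = solve_battery_optimization_hourly(
    prices, initial_B=0.0, max_capacity=1.0, max_rate=1.0
)
print(status, profit)
# power_schedule[t] > 0 means discharging (selling) in hour t;
# soc_schedule has length len(prices) + 1 (state of charge at the start of each hour).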
0    optimizer/optimization/utils.py  Normal file
41   optimizer/optimzer_plan.md       Normal file
@ -0,0 +1,41 @@
## Optimizer Definition for a Constrained n-Step-Forecast Trading Problem

We want to optimize the performance of an energy trader given a price forecast for n steps.
The battery:
- holds 1 MWh
- charges/discharges at max. 1 MW per hour (we can add/lose x * 1 MW, x \in R)

Prices are stable within a given hour (t), and we sell and buy at the same price.

### Considerations
- A single variable, P (= x), for each hour t from 0 to n-1:
  - If P > 0, it represents discharging (selling power) with a magnitude of P.
  - If P < 0, it represents charging (buying power) with a magnitude of -P.
  - If P = 0, it represents holding (doing nothing).
- If we only have forecasts for t_n and t_{n+m}, we have to **interpolate** between n .. m
  (the alternative of working with the gaps and treating dt as charge time was rejected).

#### Variables
- price_t = price per MWh at hour t
- B_t (t = 0..n) = state of the battery in MWh
- P_t (t = 0..n-1) = charge/discharge factor given the base rate of 1 MW/h
- max_p = 1 (charge/discharge limit) and battery capacity limit (both = 1)
- SoB_initial = 0
- h = horizon \in N^+

### Objective
- We **maximize**: Sum_{t=0}^{n-1} (price_t * P_t)

### Constraints
- Fixed starting state: SoB_0 = SoB_initial
- Charge/discharge limit: -max_p <= P_t <= max_p for all t = 0, ..., n-1
- Storage limit: 0 <= SoB_t - P_t <= max_p for all t = 0, ..., n-1
- Future battery state: SoB_{t+1} = SoB_t - P_t for t = 0, ..., n-1

(A compact LP formulation is given below.)
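Written out as the linear program these notes describe (sign convention: P_t > 0 means discharge/sell), the problem is:

$$
\begin{aligned}
\max_{P,\,B}\quad & \sum_{t=0}^{n-1} \text{price}_t \, P_t \\
\text{s.t.}\quad & B_0 = \text{SoB}_{\text{initial}} \\
& B_{t+1} = B_t - P_t \quad (t = 0,\dots,n-1) \\
& -\text{max\_p} \le P_t \le \text{max\_p} \quad (t = 0,\dots,n-1) \\
& 0 \le B_t \le \text{max\_p} \quad (t = 0,\dots,n)
\end{aligned}
$$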
297  optimizer/utils/model_io.py  Normal file
@ -0,0 +1,297 @@
import logging
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin # For scaler type hint
|
||||||
|
|
||||||
|
# Import necessary components from forecasting_model
|
||||||
|
from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig
|
||||||
|
from forecasting_model.train.model import LSTMForecastLightningModule
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_single_model_artifact(
|
||||||
|
model_path: Path,
|
||||||
|
config_path: Path,
|
||||||
|
input_size_path: Path,
|
||||||
|
target_scaler_path: Optional[Path] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Loads artifacts for a single trained model checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model checkpoint file (.ckpt).
|
||||||
|
config_path: Path to the corresponding main YAML config file.
|
||||||
|
input_size_path: Path to the input_size.pt file.
|
||||||
|
target_scaler_path: Optional path to the target_scaler.pt file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing loaded artifacts ('model_instance', 'feature_config',
|
||||||
|
'target_scaler', 'main_forecasting_config'), or None if loading fails.
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading single model artifact from directory: {model_path.parent}")
|
||||||
|
loaded_artifacts = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Load Config
|
||||||
|
if not config_path.is_file():
|
||||||
|
logger.error(f"Config file not found at {config_path}")
|
||||||
|
return None
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config_data = yaml.safe_load(f)
|
||||||
|
main_config = MainConfig(**config_data)
|
||||||
|
loaded_artifacts['main_forecasting_config'] = main_config
|
||||||
|
loaded_artifacts['feature_config'] = main_config.features
|
||||||
|
logger.debug(f"Loaded config from {config_path}")
|
||||||
|
|
||||||
|
# 2. Load Input Size
|
||||||
|
if not input_size_path.is_file():
|
||||||
|
logger.error(f"Input size file not found at {input_size_path}")
|
||||||
|
return None
|
||||||
|
input_size = torch.load(input_size_path)
|
||||||
|
if not isinstance(input_size, int) or input_size <= 0:
|
||||||
|
logger.error(f"Invalid input size loaded from {input_size_path}: {input_size}")
|
||||||
|
return None
|
||||||
|
logger.debug(f"Loaded input size ({input_size}) from {input_size_path}")
|
||||||
|
|
||||||
|
# 3. Load Target Scaler (Optional)
|
||||||
|
target_scaler = None
|
||||||
|
if target_scaler_path:
|
||||||
|
if not target_scaler_path.is_file():
|
||||||
|
logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without scaler.")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
target_scaler = torch.load(target_scaler_path)
|
||||||
|
# Basic check if it looks like a scaler
|
||||||
|
if not isinstance(target_scaler, (BaseEstimator, TransformerMixin)):
|
||||||
|
logger.warning(f"Loaded object from {target_scaler_path} might not be a valid scaler ({type(target_scaler)}).")
|
||||||
|
# Decide if this should be a hard failure or just a warning
|
||||||
|
else:
|
||||||
|
logger.debug(f"Loaded target scaler from {target_scaler_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading target scaler from {target_scaler_path}: {e}", exc_info=True)
|
||||||
|
# Decide if this should be a hard failure
|
||||||
|
return None # Fail hard if scaler loading fails
|
||||||
|
loaded_artifacts['target_scaler'] = target_scaler
|
||||||
|
|
||||||
|
# 4. Initialize Model Architecture
|
||||||
|
# Ensure model config forecast horizon matches feature config (should be guaranteed by MainConfig validation)
|
||||||
|
if set(main_config.model.forecast_horizon) != set(main_config.features.forecast_horizon):
|
||||||
|
logger.warning(f"Mismatch between model ({main_config.model.forecast_horizon}) and feature ({main_config.features.forecast_horizon}) forecast horizons in config {config_path}. Using feature config.")
|
||||||
|
# This might indicate an issue with the saved config, but we proceed using the feature config horizon
|
||||||
|
# main_config.model.forecast_horizon = main_config.features.forecast_horizon # Correct it for model init? Risky.
|
||||||
|
|
||||||
|
model_instance = LSTMForecastLightningModule(
|
||||||
|
model_config=main_config.model,
|
||||||
|
train_config=main_config.training, # Pass train config if needed
|
||||||
|
input_size=input_size,
|
||||||
|
target_scaler=target_scaler # Pass scaler to model if it uses it internally during inference
|
||||||
|
)
|
||||||
|
logger.debug("Initialized model architecture.")
|
||||||
|
|
||||||
|
# 5. Load Model State Dictionary
|
||||||
|
if not model_path.is_file():
|
||||||
|
logger.error(f"Model checkpoint file not found at {model_path}")
|
||||||
|
return None
|
||||||
|
# Load onto CPU first to avoid GPU memory issues if the loading machine is different
|
||||||
|
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
||||||
|
# Adjust state dict keys if saved with 'model.' prefix from Lightning wrapper common during saving ckpt
|
||||||
|
if any(key.startswith('model.') for key in state_dict.get('state_dict', state_dict).keys()):
|
||||||
|
state_dict = {k.partition('model.')[2]: v for k, v in state_dict.get('state_dict', state_dict).items()}
|
||||||
|
logger.debug("Adjusted state dict keys (removed 'model.' prefix).")
|
||||||
|
|
||||||
|
# Load the state dict
|
||||||
|
# Use strict=False initially if unsure about exact key matching, but strict=True is safer
|
||||||
|
try:
|
||||||
|
load_result = model_instance.load_state_dict(state_dict, strict=True)
|
||||||
|
logger.debug(f"Model state loaded. Result: {load_result}")
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Error loading state dict into model (strict=True): {e}. Trying strict=False.")
|
||||||
|
try:
|
||||||
|
load_result = model_instance.load_state_dict(state_dict, strict=False)
|
||||||
|
logger.warning(f"Model state loaded with strict=False. Result: {load_result}. Check for missing/unexpected keys.")
|
||||||
|
except Exception as e_false:
|
||||||
|
logger.error(f"Failed to load state dict even with strict=False: {e_false}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
model_instance.eval() # Set model to evaluation mode
|
||||||
|
loaded_artifacts['model_instance'] = model_instance
|
||||||
|
logger.info(f"Successfully loaded single model artifact: {model_path.name}")
|
||||||
|
|
||||||
|
return loaded_artifacts
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"A required file was not found during artifact loading for {model_path.parent}.", exc_info=True)
|
||||||
|
return None
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.error(f"Error parsing YAML config file {config_path}: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load single model artifact from {model_path.parent}: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_ensemble_artifact(
|
||||||
|
ensemble_definition_path: Path,
|
||||||
|
hpo_base_output_dir: Path # Base directory where HPO study results (including ensemble JSON) are saved
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Loads artifacts for an ensemble based on its definition JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ensemble_definition_path: Path to the _best_ensemble.json file.
|
||||||
|
hpo_base_output_dir: The base directory where the HPO study ran and
|
||||||
|
where relative paths within the JSON are anchored.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing 'ensemble_method', 'fold_artifacts' (a list
|
||||||
|
of dictionaries, each like the output of load_single_model_artifact),
|
||||||
|
'ensemble_feature_config', and 'ensemble_target_col', or None if loading fails.
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading ensemble artifact definition from: {ensemble_definition_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not ensemble_definition_path.is_file():
|
||||||
|
logger.error(f"Ensemble definition file not found at: {ensemble_definition_path}")
|
||||||
|
return None
|
||||||
|
with open(ensemble_definition_path, 'r') as f:
|
||||||
|
ensemble_definition = json.load(f)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Error decoding ensemble definition JSON file: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading ensemble definition: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract information from the definition
|
||||||
|
ensemble_method = ensemble_definition.get("ensemble_method")
|
||||||
|
fold_models_definitions = ensemble_definition.get("fold_models")
|
||||||
|
# Base directory for artifacts *relative to* hpo_base_output_dir
|
||||||
|
relative_artifacts_base_dir = ensemble_definition.get("ensemble_artifacts_base_dir")
|
||||||
|
|
||||||
|
if not ensemble_method or not fold_models_definitions:
|
||||||
|
logger.error("Ensemble definition file is missing 'ensemble_method' or 'fold_models' list.")
|
||||||
|
return None
|
||||||
|
if not relative_artifacts_base_dir:
|
||||||
|
logger.error("Ensemble definition file is missing 'ensemble_artifacts_base_dir'. Cannot locate fold artifacts.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- Determine Absolute Path to Fold Artifacts ---
|
||||||
|
# The paths inside fold_models are relative to ensemble_artifacts_base_dir,
|
||||||
|
# which itself is relative to hpo_base_output_dir.
|
||||||
|
absolute_artifacts_base_dir = hpo_base_output_dir / Path(relative_artifacts_base_dir)
|
||||||
|
logger.debug(f"Absolute base directory for fold artifacts: {absolute_artifacts_base_dir}")
|
||||||
|
if not absolute_artifacts_base_dir.is_dir():
|
||||||
|
logger.error(f"Calculated absolute artifact base directory does not exist or is not a directory: {absolute_artifacts_base_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
loaded_fold_artifacts: List[Dict[str, Any]] = []
|
||||||
|
common_feature_config: Optional[FeatureConfig] = None
|
||||||
|
common_target_col: Optional[str] = None
|
||||||
|
|
||||||
|
logger.info(f"Loading artifacts for {len(fold_models_definitions)} folds defined in the ensemble...")
|
||||||
|
|
||||||
|
# --- Load Artifacts for Each Fold ---
|
||||||
|
for i, fold_def in enumerate(fold_models_definitions):
|
||||||
|
fold_id = fold_def.get("fold_id", i + 1)
|
||||||
|
logger.debug(f"--- Loading Fold {fold_id} ---")
|
||||||
|
|
||||||
|
model_path_rel = fold_def.get("model_path")
|
||||||
|
scaler_path_rel = fold_def.get("target_scaler_path")
|
||||||
|
input_size_path_rel = fold_def.get("input_size_path")
|
||||||
|
config_path_rel = fold_def.get("config_path")
|
||||||
|
|
||||||
|
if not model_path_rel or not input_size_path_rel or not config_path_rel:
|
||||||
|
logger.error(f"Fold {fold_id}: Definition is missing required path(s) (model, input_size, or config). Skipping fold.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Construct absolute paths for this fold's artifacts
|
||||||
|
try:
|
||||||
|
abs_model_path = (absolute_artifacts_base_dir / Path(model_path_rel)).resolve()
|
||||||
|
abs_input_size_path = (absolute_artifacts_base_dir / Path(input_size_path_rel)).resolve()
|
||||||
|
abs_config_path = (absolute_artifacts_base_dir / Path(config_path_rel)).resolve()
|
||||||
|
abs_scaler_path = (absolute_artifacts_base_dir / Path(scaler_path_rel)).resolve() if scaler_path_rel else None
|
||||||
|
|
||||||
|
logger.debug(f"Fold {fold_id} - Model Path: {abs_model_path}")
|
||||||
|
logger.debug(f"Fold {fold_id} - Config Path: {abs_config_path}")
|
||||||
|
logger.debug(f"Fold {fold_id} - Input Size Path: {abs_input_size_path}")
|
||||||
|
logger.debug(f"Fold {fold_id} - Scaler Path: {abs_scaler_path}")
|
||||||
|
|
||||||
|
# Load the artifacts for this single fold using the other function
|
||||||
|
single_fold_loaded_artifacts = load_single_model_artifact(
|
||||||
|
model_path=abs_model_path,
|
||||||
|
config_path=abs_config_path,
|
||||||
|
input_size_path=abs_input_size_path,
|
||||||
|
target_scaler_path=abs_scaler_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if single_fold_loaded_artifacts:
|
||||||
|
# Add fold_id for reference
single_fold_loaded_artifacts['fold_id'] = fold_id
loaded_fold_artifacts.append(single_fold_loaded_artifacts)
logger.info(f"Successfully loaded artifacts for fold {fold_id}.")

# --- Consistency Check (Optional but Recommended) ---
# Store the feature config and target col from the first successful fold
# Then compare subsequent folds against these
current_feature_config = single_fold_loaded_artifacts['feature_config']
current_target_col = single_fold_loaded_artifacts['main_forecasting_config'].data.target_col

if common_feature_config is None:
common_feature_config = current_feature_config
common_target_col = current_target_col
logger.debug(f"Set common feature config and target column based on fold {fold_id}.")
else:
# Compare crucial feature engineering aspects
if common_feature_config.sequence_length != current_feature_config.sequence_length or \
set(common_feature_config.forecast_horizon) != set(current_feature_config.forecast_horizon) or \
common_feature_config.scaling_method != current_feature_config.scaling_method: # Add more checks if needed
logger.error(f"Fold {fold_id}: Feature configuration mismatch with previous folds. Cannot proceed with this ensemble definition.")
# You might want to compare more fields like lags, rolling_windows etc.
return None # Fail hard if configs are inconsistent
if common_target_col != current_target_col:
logger.error(f"Fold {fold_id}: Target column '{current_target_col}' differs from previous folds ('{common_target_col}'). Cannot proceed.")
return None # Fail hard

else:
logger.error(f"Failed to load artifacts for fold {fold_id}. Skipping fold.")
# Decide if ensemble loading should fail if *any* fold fails
# For now, we continue and will check if enough folds loaded later

except TypeError as e:
# Catch potential errors if paths are None or invalid types
logger.error(f"Fold {fold_id}: Error constructing artifact paths - check definition file content: {e}", exc_info=True)
continue
except Exception as e:
logger.error(f"Fold {fold_id}: Unexpected error during loading: {e}", exc_info=True)
continue # Skip this fold

# --- Final Checks and Return ---
if not loaded_fold_artifacts:
logger.error("Failed to load artifacts for *any* fold in the ensemble.")
return None

# Add a check if a minimum number of folds is required (e.g., > 1)
if len(loaded_fold_artifacts) < 1: # Or maybe check against len(fold_models_definitions)?
logger.error(f"Only successfully loaded {len(loaded_fold_artifacts)} folds, which might be insufficient for the ensemble.")
# Decide if this is an error or just a warning
return None

if common_feature_config is None or common_target_col is None:
# This should not happen if loaded_fold_artifacts is not empty, but check anyway
logger.error("Internal error: Could not determine common feature config or target column for the ensemble.")
return None

logger.info(f"Successfully loaded artifacts for {len(loaded_fold_artifacts)} ensemble folds.")

return {
'ensemble_method': ensemble_method,
'fold_artifacts': loaded_fold_artifacts, # List of dicts
'ensemble_feature_config': common_feature_config, # The common config
'ensemble_target_col': common_target_col # The common target column name
}
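For orientation, a minimal consumer sketch of the bundle returned above. The per-fold 'model' key and its predict() call are hypothetical illustrations only; the loader itself only shows 'fold_id', 'feature_config' and 'main_forecasting_config' on each fold entry.

import numpy as np

def ensemble_predict(bundle, batch):
    # Illustrative only: combine per-fold predictions according to 'ensemble_method'.
    preds = [fold['model'].predict(batch) for fold in bundle['fold_artifacts']]  # 'model'/'predict' are assumed names
    if bundle['ensemble_method'] == 'mean':
        return np.mean(preds, axis=0)
    if bundle['ensemble_method'] == 'median':
        return np.median(preds, axis=0)
    raise ValueError(f"Unsupported ensemble method: {bundle['ensemble_method']}")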
18 optimizer/utils/optim_config.py Normal file
@ -0,0 +1,18 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Literal

class ModelEvalConfig(BaseModel):
"""Configuration for evaluating a single forecasting model or an ensemble."""
name: str = Field(..., description="Name of the forecasting model or ensemble.")
type: Literal['model', 'ensemble'] = Field(..., description="Type of evaluation artifact: 'model' for a single checkpoint, 'ensemble' for an ensemble definition JSON.")
model_path: str = Field(..., description="Path to the saved PyTorch model file (.ckpt for type='model') or the ensemble definition JSON file (.json for type='ensemble').")
model_config_path: str = Field(..., description="Path to the forecasting config (YAML) used for this model training (or for the best trial in an ensemble).")
target_scaler_path: Optional[str] = Field(None, description="Path to the target scaler file for the single model (or will be loaded per fold for ensemble).")

class OptimizationRunConfig(BaseModel):
"""Main configuration for running battery optimization with multiple models/ensembles."""
initial_b: float = Field(..., description="Initial state of charge of the battery (MWh).")
max_capacity: float = Field(..., description="Maximum energy capacity of the battery (MWh).")
max_rate: float = Field(..., description="Maximum charge/discharge power rate of the battery (MW).")
optimization_horizon_hours: int = Field(24, gt=0, description="The length of the time window (in hours) for optimization.")
models: List[ModelEvalConfig] = Field(..., description="List of forecasting models or ensembles to evaluate.")
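For reference, a minimal sketch of constructing this configuration from a YAML file. The file name optimization_config.yaml, and the assumption that its keys mirror the field names above, are illustrative only.

import yaml
from optimizer.utils.optim_config import OptimizationRunConfig

with open("optimization_config.yaml", "r", encoding="utf-8") as f:  # hypothetical file name
    raw = yaml.safe_load(f)
run_config = OptimizationRunConfig(**raw)  # pydantic validates required fields and types
print(run_config.optimization_horizon_hours, [m.name for m in run_config.models])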
603 optuna_ensemble_run.py Normal file
@ -0,0 +1,603 @@
import argparse
import logging
import sys
import warnings
import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml
import json # Import json to save best ensemble definition

import optuna

# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
# prepare_fold_data_and_loaders used by run_single_fold
)
# Import the single fold runner from the main script
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
from typing import List, Optional, Tuple, Dict, Any # Added Any for dictionary

# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds, aggregate_cv_metrics, save_results

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning, # These specific warnings are often UserWarnings
module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING) # Set PL to WARNING by default, INFO/DEBUG set below if needed

# --- Basic Logging Setup ---
# Configure logging early. Level will be set properly later based on config.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()


# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments for Optuna Ensemble HPO."""
parser = argparse.ArgumentParser(
description="Run HPO optimizing ensemble performance using Optuna.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config', type=str, default='forecasting_config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--output-dir', type=str, default=None,
help="Override base output directory for HPO results."
)
parser.add_argument(
'--keep-artifacts', action='store_true',
help="Prevent cleanup of trial directories after best trial is determined."
)
args = parser.parse_args()
return args

# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig,
df: pd.DataFrame,
hpo_base_output_dir: Path # Pass base dir for trial outputs
) -> float: # Return the single ensemble metric to optimize
"""
Optuna objective function optimizing ensemble performance.
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()

# Define trial-specific output directory for fold artifacts
trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")

hpo_config = base_config.optuna
# Metric for pruning based on individual fold performance
validation_metric_monitor = hpo_config.metric_to_optimize
# Ensemble metric and method to optimize (e.g., MAE of the 'mean' ensemble)
ensemble_metric_optimize = 'MAE'
ensemble_method_optimize = 'mean'
optimization_direction = hpo_config.direction # 'minimize' or 'maximize'
worst_value = float('inf') if optimization_direction == 'minimize' else float('-inf')

# Store paths and details for all saved artifacts for this trial's folds
fold_artifact_details: List[Dict[str, Any]] = [] # Changed to list of dicts

# --- 1. Suggest Hyperparameters ---
try:
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
except Exception as e:
logger.error(f"Trial {trial.number}: Failed to deep copy base config: {e}", exc_info=True)
return worst_value

# ----- Suggest Hyperparameters -----
# Modify trial_config_dict using trial.suggest_*
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Suggestions -----

# --- 2. Re-validate Trial Config ---
try:
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Disable plotting during HPO runs to save time/resources
trial_config_dict['evaluation']['save_plots'] = False
trial_config = MainConfig(**trial_config_dict)
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid config generated: {e}", exc_info=True)
return worst_value

# --- Early check for data length ---
# ... (Keep the check as in optuna_run.py) ...
try:
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
raise ValueError("Trial config has invalid forecast_horizon list.")
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
if min_data_for_sequence > len(df):
logger.warning(f"Trial {trial.number}: Skipped. sequence_length + max_horizon ({min_data_for_sequence}) exceeds data length ({len(df)}).")
# Report worst value so Optuna knows this trial failed badly
# Using study direction to determine the appropriate "worst" value
return worst_value
# raise optuna.TrialPruned() # Alternative: Prune instead of returning worst value
except Exception as e:
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
return worst_value

# --- 3. Run Cross-Validation Training (Saving Artifacts) ---
all_fold_best_val_scores = {} # Store best val scores for pruning
actual_folds_trained = 0
# Store paths to saved models and scalers for this trial
fold_model_paths = []
fold_scaler_paths = []
try:
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))

for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_id = fold_num + 1
logger.info(f"Trial {trial.number}, Fold {fold_id}/{cv_splitter.n_splits}: Training model...")
current_fold_best_metric = None # Reset for each fold

try:
# Use run_single_fold - it handles training and saving artifacts
# Pass trial_output_dir so fold artifacts are saved per trial
fold_metrics, best_val_score, saved_model_path, saved_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
config=trial_config, # Use the config with trial's hyperparameters
full_df=df,
output_base_dir=trial_artifacts_dir # Save folds under trial dir
)
actual_folds_trained += 1
all_fold_best_val_scores[fold_id] = best_val_score

# Store all artifact paths for this fold
fold_artifact_details.append({
"fold_id": fold_id,
"model_path": str(saved_model_path) if saved_model_path else None,
"target_scaler_path": str(saved_scaler_path) if saved_scaler_path else None,
"input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
"config_path": str(saved_config_path) if saved_config_path else None,
})

# Check if the monitored validation metric was returned
if best_val_score is not None and np.isfinite(best_val_score):
current_fold_best_metric = best_val_score
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best val score ({validation_metric_monitor}) = {current_fold_best_metric:.4f}")
else:
# Use worst value if metric is missing/invalid for pruning
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Invalid or missing validation score ({validation_metric_monitor}). Using {worst_value} for pruning.")
current_fold_best_metric = worst_value # Assign worst for pruning report

# Report intermediate value (individual fold validation score) for pruning
trial.report(current_fold_best_metric, fold_num)
if trial.should_prune():
logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
raise optuna.TrialPruned()

except optuna.TrialPruned:
raise # Propagate prune signal
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
all_fold_best_val_scores[fold_id] = None # Mark fold as failed
# Continue to next fold if possible, but report worst value for this fold
trial.report(worst_value, fold_num)
# Optionally raise prune here if too many folds fail? Or let the ensemble eval handle it.

except optuna.TrialPruned:
logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
return worst_value # Return worst value when pruned
except Exception as e:
logger.critical(f"Trial {trial.number}: Failed critically during CV training setup/loop: {e}", exc_info=True)
return worst_value


# --- 4. Run Ensemble Evaluation ---
if actual_folds_trained < 2:
logger.error(f"Trial {trial.number}: Only {actual_folds_trained} folds trained successfully. Cannot run ensemble evaluation.")
return worst_value # Not enough models for ensemble

logger.info(f"Trial {trial.number}: Starting Ensemble Evaluation using {actual_folds_trained} trained models...")
ensemble_metric_final = worst_value # Initialize to worst
best_ensemble_method_for_trial = None # Track the best method for this trial
try:
# Run evaluation using the artifacts saved in the trial's output directory
ensemble_results = run_ensemble_evaluation(
config=trial_config, # Pass trial config
output_base_dir=trial_artifacts_dir # Directory containing trial's fold subdirs
)

if ensemble_results:
# Aggregate the results to get the final objective value
ensemble_metrics_for_method = []
for fold_num, fold_res in ensemble_results.items():
if fold_res and ensemble_method_optimize in fold_res:
method_metrics = fold_res[ensemble_method_optimize]
if method_metrics and ensemble_metric_optimize in method_metrics:
metric_val = method_metrics[ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
ensemble_metrics_for_method.append(metric_val)
else:
logger.warning(f"Trial {trial.number}: Invalid ensemble metric value found for fold {fold_num}, method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
else:
logger.warning(f"Trial {trial.number}: Metric '{ensemble_metric_optimize}' not found for method '{ensemble_method_optimize}' in fold {fold_num}.")
else:
logger.warning(f"Trial {trial.number}: Ensemble method '{ensemble_method_optimize}' results not found for fold {fold_num}.")

if not ensemble_metrics_for_method:
logger.error(f"Trial {trial.number}: No valid ensemble metrics found for method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
ensemble_metric_final = worst_value
else:
# Calculate the mean of the chosen ensemble metric across test folds
ensemble_metric_final = np.mean(ensemble_metrics_for_method)
logger.info(f"Trial {trial.number}: Final Ensemble Metric (Avg {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")

# Determine the best ensemble method based on average performance across folds
# This requires re-calculating averages for *all* methods evaluated by run_ensemble_evaluation
all_ensemble_methods = set()
for fold_res in ensemble_results.values():
if fold_res: all_ensemble_methods.update(fold_res.keys())

avg_metrics_per_method = {}
for method in all_ensemble_methods:
method_metrics_across_folds = []
for fold_res in ensemble_results.values():
if fold_res and method in fold_res and ensemble_metric_optimize in fold_res[method]:
metric_val = fold_res[method][ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
method_metrics_across_folds.append(metric_val)
if method_metrics_across_folds:
avg_metrics_per_method[method] = np.mean(method_metrics_across_folds)

if avg_metrics_per_method:
if optimization_direction == 'minimize':
best_ensemble_method_for_trial = min(avg_metrics_per_method, key=avg_metrics_per_method.get)
else: # maximize
best_ensemble_method_for_trial = max(avg_metrics_per_method, key=avg_metrics_per_method.get)
logger.info(f"Trial {trial.number}: Best performing ensemble method for this trial (based on avg {ensemble_metric_optimize}): {best_ensemble_method_for_trial}")
else:
logger.warning(f"Trial {trial.number}: Could not determine best ensemble method based on average {ensemble_metric_optimize}.")

else:
logger.error(f"Trial {trial.number}: Ensemble evaluation function returned no results.")
ensemble_metric_final = worst_value

except Exception as e:
logger.error(f"Trial {trial.number}: Failed during ensemble evaluation phase: {e}", exc_info=True)
ensemble_metric_final = worst_value


# --- 5. Return Final Objective Value ---
trial_duration = time.perf_counter() - trial_start_time
logger.info(f"--- Trial {trial.number}: Finished ---")
logger.info(f"  Final Objective Value (Avg Ensemble {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
logger.info(f"  Total time: {trial_duration:.2f}s")

# Store ensemble evaluation results and best method in trial user attributes
# This makes it easier to retrieve the best ensemble details after the study
trial.set_user_attr("ensemble_evaluation_results", ensemble_results)
trial.set_user_attr("best_ensemble_method", best_ensemble_method_for_trial)
# trial.set_user_attr("fold_model_paths", fold_model_paths) # Removed
# trial.set_user_attr("fold_scaler_paths", fold_scaler_paths) # Removed
trial.set_user_attr("fold_artifact_details", fold_artifact_details) # Added comprehensive artifact details

return ensemble_metric_final

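To make the returned objective value concrete, here is a self-contained toy example of the step-4 aggregation above; the fold results and numbers are invented, and only the fold -> method -> metric nesting is taken from the code.

import numpy as np

toy_results = {  # assumed shape of run_ensemble_evaluation output: {fold_num: {method: {metric: value}}}
    1: {"mean": {"MAE": 10.0, "RMSE": 14.0}, "median": {"MAE": 11.0}},
    2: {"mean": {"MAE": 12.0, "RMSE": 16.0}, "median": {"MAE": 10.5}},
}
mae_per_fold = [fold["mean"]["MAE"] for fold in toy_results.values()]
objective_value = float(np.mean(mae_per_fold))  # 11.0 would be returned to Optuna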
# --- Main HPO Execution ---
def run_hpo():
"""Main execution function for HPO optimizing ensemble performance."""
args = parse_arguments()
config_path = Path(args.config)
try:
base_config = load_config(config_path)
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)

# --- Setup Output Dir ---
if args.output_dir:
hpo_base_output_dir = Path(args.output_dir)
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
else:
# Fallback to default if output_dir is not in config either
main_output_dir_str = getattr(base_config, 'output_dir', 'output')
if not main_output_dir_str: # Handle empty string case
main_output_dir_str = 'output'
main_output_dir = Path(main_output_dir_str)
hpo_base_output_dir = main_output_dir / f'{base_config.optuna.study_name}_ensemble_hpo' # Specific subdir using study name
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")

# --- Setup Logging ---
try:
level_name = base_config.log_level.upper()
effective_log_level = logging.getLevelName(level_name)
# Ensure study name is filesystem-safe if used directly
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name)
log_file = hpo_base_output_dir / f"{safe_study_name}_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') # Specify encoding
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# Prevent adding duplicate handlers if script/function is called multiple times
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.setLevel(effective_log_level)
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
except (AttributeError, ValueError, TypeError) as e: # Added TypeError
logger.warning(f"Could not set log level from config: {e}. Defaulting to INFO.")
logger.setLevel(logging.INFO)
# Still try to log to a default file if possible
try:
log_file = hpo_base_output_dir / "default_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.info(f"Logging to default file: {log_file}")
except Exception as log_e:
logger.error(f"Failed to set up default file logging: {log_e}")


# --- Setup Seeding ---
set_seeds(getattr(base_config, 'random_seed', 42))

# --- Load Data ---
try:
logger.info("Loading base dataset for HPO...")
df = load_raw_data(base_config.data)
logger.info(f"Base dataset loaded. Shape: {df.shape}")
except FileNotFoundError as e:
logger.critical(f"Data file not found: {e}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
sys.exit(1)

# --- Optuna Study Setup ---
try:
hpo_config = base_config.optuna
if not hpo_config.enabled:
logger.info("Optuna optimization is disabled in the configuration.")
sys.exit(0)
except AttributeError:
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)

storage_path = hpo_config.storage
if storage_path and storage_path.startswith("sqlite:///"):
db_path_str = storage_path.replace("sqlite:///", "")
if not db_path_str:
# Default filename if only 'sqlite:///' is provided
db_path = hpo_base_output_dir / f"{base_config.optuna.study_name}.db"
logger.warning(f"SQLite path was empty, defaulting to: {db_path}")
else:
db_path = Path(db_path_str)

if not db_path.is_absolute():
db_path = hpo_base_output_dir / db_path
db_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent dir exists
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using SQLite storage: {storage_path}")
elif storage_path:
logger.info(f"Using Optuna storage: {storage_path} (Assuming non-SQLite or pre-configured)")
else:
storage_path = None # Explicitly set to None for in-memory
logger.warning("No Optuna storage DB specified, using in-memory storage.")

try:
# Single objective study based on ensemble performance
study = optuna.create_study(
study_name=hpo_config.study_name,
storage=storage_path,
direction=hpo_config.direction, # 'minimize' or 'maximize'
load_if_exists=True,
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
)

# --- Run Optimization ---
logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df, hpo_base_output_dir), # Pass base_config and output dir
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True # Garbage collect after trial
)

# --- Report and Save Best Trial ---
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

# Filter trials to find the actual best one (excluding pruned/failed)
try:
best_trial = study.best_trial
except ValueError: # Optuna raises ValueError if no trials completed successfully
best_trial = None
logger.warning("No successful trials completed. Cannot determine best trial.")


if best_trial:
logger.info("--- Best Trial ---")
logger.info(f"  Trial Number: {best_trial.number}")
# Ensure value is not None before formatting
best_value_str = f"{best_trial.value:.6f}" if best_trial.value is not None else "N/A"
logger.info(f"  Objective Value (Ensemble Metric): {best_value_str}")
logger.info(f"  Hyperparameters:")
best_params = best_trial.params
for key, value in best_params.items():
logger.info(f"    {key}: {value}")

# Save best hyperparameters
best_params_file = hpo_base_output_dir / f"{safe_study_name}_best_params.json"
try:
with open(best_params_file, 'w', encoding='utf-8') as f:
import json
json.dump(best_params, f, indent=4)
logger.info(f"Best hyperparameters saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save best parameters: {e}", exc_info=True)

# Save the corresponding config
best_config_file = hpo_base_output_dir / f"{safe_study_name}_best_config.yaml"
try:
# Use a fresh deepcopy to avoid modifying the original base_config
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))

# Update with best trial's hyperparameters
# Pitfall: This assumes keys match exactly and exist in these sections.
# A more robust approach might involve checking key existence or
# iterating through the config structure if params are nested differently.
for key, value in best_params.items():
if key in best_config_dict.get('training', {}): best_config_dict['training'][key] = value
elif key in best_config_dict.get('model', {}): best_config_dict['model'][key] = value
elif key in best_config_dict.get('features', {}): best_config_dict['features'][key] = value
else:
logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")

# Ensure forecast horizon is preserved from the original config
best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Maybe remove optuna section from the best config?
# best_config_dict.pop('optuna', None)

with open(best_config_file, 'w', encoding='utf-8') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
logger.info(f"Configuration for best trial saved to {best_config_file}")

except Exception as e:
logger.error(f"Failed to save best configuration: {e}", exc_info=True)

# Retrieve saved artifact paths and best ensemble method from user attributes
best_trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
# fold_model_paths = best_trial.user_attrs.get("fold_model_paths", []) # Removed
# fold_scaler_paths = best_trial.user_attrs.get("fold_scaler_paths", []) # Removed
fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", []) # Retrieve comprehensive details

if not best_trial_artifacts_dir.exists():
logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
elif not best_ensemble_method:
logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
elif not fold_artifact_details: # Check if any artifact details were recorded
logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
else:
# --- Save Best Ensemble Definition ---
logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")

ensemble_definition_file = hpo_base_output_dir / f"{safe_study_name}_best_ensemble.json"

best_ensemble_definition = {
"trial_number": best_trial.number,
"objective_value": best_trial.value,
"hyperparameters": best_trial.params,
"ensemble_method": best_ensemble_method,
"fold_models": [], # List of dictionaries for each fold's model and scaler, input_size, config
"ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(hpo_base_output_dir)) # Save path relative to hpo_base_output_dir
}

# Populate fold_models with paths to saved artifacts
for artifact_detail in fold_artifact_details:
fold_def = {
"fold_id": artifact_detail.get("fold_id"), # Include fold ID
"model_path": None,
"target_scaler_path": None,
"input_size_path": None,
"config_path": None,
}

# Process each path, making it relative if possible
for key in ["model_path", "target_scaler_path", "input_size_path", "config_path"]:
abs_path_str = artifact_detail.get(key)
if abs_path_str:
abs_path = Path(abs_path_str)
try:
# Make path relative to the trial artifacts dir
relative_path = str(abs_path.relative_to(best_trial_artifacts_dir))
fold_def[key] = relative_path
except ValueError:
logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
fold_def[key] = str(abs_path) # Fallback to absolute path

best_ensemble_definition["fold_models"].append(fold_def)


try:
with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
json.dump(best_ensemble_definition, f, indent=4)
logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
except Exception as e:
logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)

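For orientation, the saved definition has the following shape, written here as a Python literal; every value below is a hypothetical placeholder, and only the keys come from the code above.

example_best_ensemble_definition = {
    "trial_number": 17,                                              # hypothetical
    "objective_value": 12.34,                                        # hypothetical average ensemble MAE
    "hyperparameters": {"learning_rate": 1e-3, "hidden_size": 146},  # hypothetical subset
    "ensemble_method": "mean",
    "fold_models": [
        {
            "fold_id": 1,
            "model_path": "fold_01/model.ckpt",                 # hypothetical relative path
            "target_scaler_path": "fold_01/target_scaler.pt",   # hypothetical
            "input_size_path": "fold_01/input_size.json",       # hypothetical
            "config_path": "fold_01/config.yaml",               # hypothetical
        },
    ],
    "ensemble_artifacts_base_dir": "ensemble_runs_artifacts/trial_17",  # hypothetical
}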
# --- Optional: Clean up artifact directories for non-best trials ---
if not args.keep_artifacts:
logger.info("Cleaning up artifact directories for non-best trials...")
ensemble_artifacts_base_dir = hpo_base_output_dir / "ensemble_runs_artifacts"
if ensemble_artifacts_base_dir.exists():
for item in ensemble_artifacts_base_dir.iterdir():
if item.is_dir():
# Check if this directory belongs to the best trial
if best_trial and item.name == f"trial_{best_trial.number}":
logger.debug(f"Keeping artifact directory for best trial: {item}")
continue
else:
logger.debug(f"Removing artifact directory for non-best trial: {item}")
try:
import shutil
shutil.rmtree(item)
except Exception as e:
logger.error(f"Failed to remove directory {item}: {e}", exc_info=True)
else:
logger.debug(f"Artifacts base directory not found for cleanup: {ensemble_artifacts_base_dir}")


except optuna.exceptions.StorageInternalError as e:
logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)


if __name__ == "__main__":
run_hpo() # Changed main() to run_hpo()
605 optuna_run.py
@ -1,35 +1,44 @@
import argparse
import logging
import sys
import warnings # Import the warnings module

import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml # Added for saving best config

import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping
# Import the Optuna callback for pruning
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback

# Import necessary components from the forecasting_model package
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
prepare_fold_data_and_loaders,
TimeSeriesCrossValidationSplitter,
split_data_classic
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.train.model import LSTMForecastLightningModule
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
from forecasting_model.train.classic import run_classic_training
# from forecasting_model.evaluation import ...
from typing import Dict, List, Any, Optional

# Import helper functions from forecasting_model.py (or move them to a shared utils file)
# For now, let's redefine simplified versions or assume they exist in utils
# Import helper functions from forecasting_model_run.py
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
from forecasting_model.utils.helper import load_config, set_seeds

# Import the data processing functions
from forecasting_model.data_processing import load_raw_data

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning,
module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
@ -37,358 +46,396 @@ mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
datefmt='%H:%M:%S')
root_logger = logging.getLogger()
# Get the root logger
logger = logging.getLogger(__name__) # Logger for this script
logger = logging.getLogger()
optuna_lg = logging.getLogger('optuna') # Optuna's logger


# --- Argument Parsing ---
# --- Argument Parsing (Simplified) ---
def parse_arguments():
"""Parses command-line arguments for Optuna HPO."""
parser = argparse.ArgumentParser(
description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
description="Run Hyperparameter Optimization using Optuna.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
'-c', '--config', type=str, default='forecasting_config.yaml',
type=str,
help="Path to the YAML configuration file containing HPO settings."
default='config.yaml',
help="Path to the BASE YAML configuration file."
)
parser.add_argument(
'--output-dir',
'--output-dir', type=str, default=None,
type=str,
help="Override output directory specified in the configuration file."
default='output/hpo_results',
help="Directory for saving Optuna study database and potentially best trial info."
)
parser.add_argument(
'--study-name',
type=str,
default='lstm_forecasting_hpo',
help="Name for the Optuna study."
)
parser.add_argument(
'--n-trials',
type=int,
default=20,
help="Number of Optuna trials to run."
)
parser.add_argument(
'--storage-db',
type=str,
default=None, # Default to in-memory if not specified
help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage."
)
parser.add_argument(
'--metric-to-optimize',
type=str,
default='val_mae_orig_scale',
help="Metric logged during validation to optimize (must match metric name in LightningModule)."
)
parser.add_argument(
'--direction',
type=str,
default='minimize',
choices=['minimize', 'maximize'],
help="Direction for Optuna optimization."
)
parser.add_argument(
'--pruning',
action='store_true',
help="Enable Optuna's trial pruning based on intermediate validation results."
)
parser.add_argument(
'--seed',
type=int,
default=42, # Fixed seed for the HPO process itself
help="Random seed for the main HPO script."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)

args = parser.parse_args()
return args

# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig, # Pass the loaded base config
base_config: MainConfig,
df: pd.DataFrame, # Pass the loaded data
df: pd.DataFrame,
output_base_dir: Path, # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
) -> float: # Ensure it returns a float
metric_to_optimize: str,
enable_pruning: bool
) -> float:
"""
Optuna objective function. Trains and evaluates one set of hyperparameters
Optuna single-objective function using a classic train/val/test split.
using cross-validation and returns the average validation metric.
Returns:
- Validation score from the classic split (minimize).
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()

# Get HPO settings
hpo_config = base_config.optuna
# Use the specific metric name from config for validation monitoring
validation_metric_monitor = hpo_config.metric_to_optimize
# The objective will be based on the validation metric
# Hardcode minimization for the single objective
monitor_mode = "min"
worst_value = float('inf') # Represents a poor validation score

# --- 1. Suggest Hyperparameters ---
# Make a deep copy of the base config to modify for this trial
# Using dict conversion and back might be easier than Pydantic's copy for deep nested updates
try:
trial_config_dict = copy.deepcopy(base_config.dict()) # Convert to dict for easier modification
# Deep copy the base config dictionary to modify for the trial
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python')) # Use mode='python' for easier modification
except Exception as e:
logger.error(f"Failed to deep copy base configuration: {e}")
logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
raise # Cannot proceed without config
# Return worst value
return worst_value

# Suggest values for hyperparameters we want to tune
# ----- Suggest Hyperparameters -----
# Example suggestions (adjust ranges and types as needed):
# (Suggest parameters as before, modifying trial_config_dict)
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
# Example: Suggest sequence length? (Requires careful handling as it affects data prep)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
# trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
# Note: forecast_horizon is NOT tuned here, taken from base_config
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Hyperparameter Suggestions -----

# --- 2. Re-validate Trial Config (Optional but Recommended) ---
# --- 2. Re-validate Trial Config ---
try:
# Transfer forecast horizon from base config (not tuned)
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
trial_config = MainConfig(**trial_config_dict)
logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
# Return a high value (for minimization) to penalize invalid configs
return worst_value
return float('inf')

# --- Early check for invalid sequence length / forecast horizon combination ---
try:
# Make sure forecast_horizon is a list before checking max()
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
raise ValueError("Trial config has invalid forecast_horizon list.")
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
if min_data_for_sequence > len(df):
logger.warning(f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
f"exceeds data length ({len(df)}).")
# Optuna doesn't directly support skipping, so return worst values
return worst_value
except Exception as e:
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
return worst_value

# --- 3. Run Classic Train/Test ---
logger.info(f"Trial {trial.number}: Starting Classic Run...")
validation_metric_value = worst_value # Initialize to worst
try:
n_samples = len(df)
val_frac = trial_config.cross_validation.val_size_fraction
test_frac = trial_config.cross_validation.test_size_fraction
train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)

# --- 3. Run Cross-Validation for this Trial ---
# Prepare data for classic split
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, input_size_cl = prepare_fold_data_and_loaders(
fold_best_val_metrics: List[Optional[float]] = []
full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
target_col=trial_config.data.target_col, feature_config=trial_config.features,
train_config=trial_config.training, eval_config=trial_config.evaluation
)

for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
# Initialize Model
fold_id = fold_num + 1
model_cl = LSTMForecastLightningModule(
logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
model_config=trial_config.model, train_config=trial_config.training,
fold_start_time = time.perf_counter()
input_size=input_size_cl, target_scaler=target_scaler_cl
)

# Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
# Callbacks (EarlyStopping and Pruning)
# fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
callbacks_cl = []
# fold_trial_dir.mkdir(parents=True, exist_ok=True)
if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
callbacks_cl.append(EarlyStopping(
monitor=validation_metric_monitor, # Monitor validation metric
patience=trial_config.training.early_stopping_patience,
mode=monitor_mode, verbose=False
))

try:
# Trainer for classic run
# --- Per-Fold Data Prep ---
trainer_cl = pl.Trainer(
# Use trial_config for batch sizes etc.
accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1 if torch.cuda.is_available() else None,
train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
max_epochs=trial_config.training.epochs, callbacks=callbacks_cl, logger=False, # logger=False as per original
full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, # Test loader not needed here
enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
target_col=trial_config.data.target_col,
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
feature_config=trial_config.features,
precision=getattr(trial_config.training, 'precision', 32),
train_config=trial_config.training,
)
eval_config=trial_config.evaluation # Pass eval for batch size if needed by prep?
)

# --- Model Instantiation ---
# --- Train Model ---
current_model_config = trial_config.model.copy(update={'input_size': input_size,
logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
'forecast_horizon': trial_config.features.forecast_horizon})
trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)
model = LSTMForecastLightningModule(
model_config=current_model_config,
train_config=trial_config.training,
target_scaler=target_scaler
)

# --- Callbacks for this Trial/Fold ---
# --- Get Best Validation Score ---
# Monitor the metric Optuna cares about
# Check early stopping callback first if it exists
monitor_mode = "min" if args.direction == "minimize" else "max"
best_score_tensor = None
if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
best_score_tensor = callbacks_cl[0].best_score
elif callbacks_cl[0].stopped_epoch > 0 : # Early stopping triggered
logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")

callbacks = []
# If early stopping didn't capture best score, use last metrics from trainer
if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
if best_score_tensor is None:
early_stopping = EarlyStopping(
metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
monitor=metric_to_optimize,
if metric_val is not None:
patience=trial_config.training.early_stopping_patience,
best_score_tensor = metric_val # Use the last logged value
mode=monitor_mode,
verbose=False # Less verbose during HPO
)
callbacks.append(early_stopping)

# Add Optuna Pruning Callback
if best_score_tensor is None:
if enable_pruning:
logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
validation_metric_value = worst_value
callbacks.append(pruning_callback)
else:
validation_metric_value = best_score_tensor.item()
logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")

# Optional: LR Monitor
# Report intermediate value for pruning (if enabled)
# callbacks.append(LearningRateMonitor(logging_interval='epoch'))
trial.report(validation_metric_value, trainer_cl.current_epoch)
if trial.should_prune():
logger.info(f"Trial {trial.number}: Pruned.")
raise optuna.TrialPruned()

# --- Trainer for this Trial/Fold ---
# Note: We don't run prediction/evaluation on the test set here,
trainer = pl.Trainer(
# as the objective is based on validation performance.
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
# The test set evaluation will be done later for the best trial.
devices=1 if torch.cuda.is_available() else None,
max_epochs=trial_config.training.epochs,
callbacks=callbacks,
logger=False, # Disable default PL logging during HPO
enable_checkpointing=False, # Disable checkpoint saving during HPO
enable_progress_bar=False, # Disable progress bar for cleaner logs
enable_model_summary=False, # Disable model summary
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
precision=getattr(trial_config.training, 'precision', 32),
# Log GPU usage if available?
# log_gpu_memory='min_max',
)

# --- Fit the Model ---
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")
logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# --- Get Best Validation Score for Pruning/Reporting ---
except optuna.TrialPruned:
# Access the monitored metric value from the trainer's logged metrics or callback state
# Propagate prune signal, objective will be set to worst later by Optuna
# Ensure the key matches exactly what's logged in validation_step
raise
best_val_score = trainer.callback_metrics.get(metric_to_optimize)
except Exception as e:
logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)
if best_val_score is None:
validation_metric_value = worst_value # Assign worst value if classic run fails
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using inf/nan.")
finally:
# Handle cases where training might have failed or metric wasn't logged
# Clean up GPU memory after the run
best_val_score = float('inf') if monitor_mode == 'min' else float('-inf') # Return worst possible value
del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl
else:
if torch.cuda.is_available(): torch.cuda.empty_cache()
best_val_score = best_val_score.item() # Convert tensor to float
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")

fold_best_val_metrics.append(best_val_score)

# --- Intermediate Pruning Report (Optional but Recommended) ---
|
|
||||||
# Report the intermediate value (best score for this fold) to Optuna
|
|
||||||
# trial.report(best_val_score, fold_id) # Report score at step `fold_id`
|
|
||||||
# Check if the trial should be pruned based on reported values
|
|
||||||
# if trial.should_prune():
|
|
||||||
# logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
|
|
||||||
# raise optuna.TrialPruned()
|
|
||||||
|
|
||||||
logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")
|
|
||||||
|
|
||||||
except optuna.TrialPruned:
|
|
||||||
# Re-raise prune exception to let Optuna handle it
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
|
|
||||||
# Record a failure for this fold (e.g., append NaN or worst value)
|
|
||||||
fold_best_val_metrics.append(float('inf') if monitor_mode == 'min' else float('-inf'))
|
|
||||||
# Optionally: Break the CV loop for this trial if one fold fails catastrophically?
|
|
||||||
# break
|
|
||||||
|
|
||||||
|
|
||||||
# --- 4. Calculate Average Metric Across Folds ---
|
# --- 4. Return Objective ---
|
||||||
if not fold_best_val_metrics:
|
|
||||||
logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
|
|
||||||
return float('inf') # Return worst value
|
|
||||||
|
|
||||||
# Handle potential infinities or NaNs from failed folds
|
|
||||||
valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
|
|
||||||
if not valid_scores:
|
|
||||||
logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
|
|
||||||
return float('inf')
|
|
||||||
|
|
||||||
average_val_metric = np.mean(valid_scores)
|
|
||||||
logger.info(f"--- Trial {trial.number}: Finished ---")
|
logger.info(f"--- Trial {trial.number}: Finished ---")
|
||||||
logger.info(f" Average validation {metric_to_optimize}: {average_val_metric:.5f}")
|
logger.info(f" Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}")
|
||||||
logger.info(f" Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
|
logger.info(f" Total time: {time.perf_counter() - trial_start_time:.2f}s")
|
||||||
|
|
||||||
# --- 5. Return Metric for Optuna ---
|
# Return the single objective (validation metric)
|
||||||
return average_val_metric
|
return float(validation_metric_value)
|
||||||
|
|
||||||
|
|
||||||
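The rewritten objective above returns a single validation metric and relies on trial.report / trial.should_prune for pruning. The following is a minimal, self-contained sketch of that same report/prune/return pattern, with a dummy training loop standing in for the Lightning trainer; the helper name dummy_train_epoch and its loss formula are purely illustrative and not part of this project.

import optuna

def dummy_train_epoch(lr: float, epoch: int) -> float:
    # Stand-in for one training epoch: returns a pretend validation loss
    return 1.0 / (epoch + 1) + abs(lr - 1e-3)

def sketch_objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    best_val = float("inf")
    for epoch in range(20):
        val_loss = dummy_train_epoch(lr, epoch)
        best_val = min(best_val, val_loss)
        trial.report(val_loss, epoch)        # intermediate value for the pruner
        if trial.should_prune():
            raise optuna.TrialPruned()       # let Optuna record the trial as pruned
    return float(best_val)                   # single objective, lower is better

study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(sketch_objective, n_trials=25)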
 # --- Main HPO Execution ---
 def run_hpo():
     """Main execution function for HPO."""
-    global args # Make args accessible in objective (simplifies passing) - or use functools.partial
     args = parse_arguments()
     config_path = Path(args.config)
-    output_dir = Path(args.output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True) # Ensure output dir exists
-
-    # Adjust log level if debug flag is set
-    if args.debug:
-        root_logger.setLevel(logging.DEBUG)
-        optuna_lg.setLevel(logging.DEBUG)
-        pl_logger.setLevel(logging.DEBUG)
-        logger.debug("Debug mode enabled.")
-    else:
-        # Reduce verbosity during HPO runs
-        optuna_lg.setLevel(logging.WARNING)
-        pl_logger.setLevel(logging.INFO) # Keep INFO for PL start/end messages

-    # --- Configuration Loading ---
     try:
-        base_config = load_config(config_path)
-    except Exception:
+        base_config = load_config(config_path) # Load base config once
+        logger.info(f"Successfully loaded configuration from {config_path}")
+    except Exception as e:
+        logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
         sys.exit(1)

-    # --- Seed Setting (for HPO script itself) ---
-    set_seeds(args.seed)
-
-    # --- Load Data Once ---
-    # Assume data doesn't change based on HPs (unless sequence_length is tuned heavily)
-    try:
-        logger.info("Loading base dataset...")
-        df = load_raw_data(base_config.data)
-        logger.info("Base dataset loaded.")
-    except Exception as e:
-        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
-        sys.exit(1)
+    # Setup output dir...
+    if args.output_dir:
+        hpo_base_output_dir = Path(args.output_dir) # Use specific name for HPO dir
+        logger.info(f"Using HPO output directory from command line: {hpo_base_output_dir}")
+    elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
+        hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
+        logger.info(f"Using HPO output directory from Optuna storage path: {hpo_base_output_dir}")
+    else:
+        # Use output_dir from main config if available, otherwise default
+        main_output_dir = Path(getattr(base_config, 'output_dir', 'output'))
+        hpo_base_output_dir = main_output_dir / 'hpo_results'
+        logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
+    hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Setup logging... (ensure file handler uses hpo_base_output_dir)
+    try:
+        level_name = base_config.log_level.upper()
+        effective_log_level = logging.getLevelName(level_name)
+        log_file = hpo_base_output_dir / f"{base_config.optuna.study_name}_hpo.log"
+        file_handler = logging.FileHandler(log_file, mode='a')
+        formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+        file_handler.setFormatter(formatter)
+        # Add handler only if it's not already added (e.g., if run_hpo is called multiple times)
+        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
+            logger.addHandler(file_handler)
+        logger.setLevel(effective_log_level)
+        logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
+        if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
+    except (AttributeError, ValueError) as e:
+        logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
+        logger.setLevel(logging.INFO)
+
+    # Setup seeding...
+    try:
+        seed = base_config.random_seed
+        set_seeds(seed)
+    except AttributeError:
+        logger.warning("Config missing 'random_seed'. Using default seed 42.")
+        set_seeds(42)
+
+    # Load data...
+    try:
+        logger.info("Loading base dataset for HPO...")
+        df = load_raw_data(base_config.data)
+        logger.info(f"Base dataset loaded. Shape: {df.shape}")
+    except Exception as e:
+        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
+        sys.exit(1)

     # --- Optuna Study Setup ---
-    storage_path = args.storage_db
-    if storage_path:
-        # Ensure directory exists if using SQLite file storage
-        db_path = Path(storage_path.replace("sqlite:///", ""))
-        db_path.parent.mkdir(parents=True, exist_ok=True)
-        storage_path = f"sqlite:///{db_path.resolve()}" # Use absolute path
-        logger.info(f"Using Optuna storage: {storage_path}")
-    else:
-        logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")
+    try:
+        hpo_config = base_config.optuna
+        if not hpo_config.enabled:
+            logger.info("Optuna optimization is disabled in the configuration.")
+            sys.exit(0)
+    except AttributeError:
+        logger.critical("Optuna configuration section ('optuna') missing.")
+        sys.exit(1)
+
+    storage_path = hpo_config.storage
+    if storage_path and storage_path.startswith("sqlite:///"):
+        db_path = Path(storage_path.replace("sqlite:///", ""))
+        if not db_path.is_absolute():
+            db_path = hpo_base_output_dir / db_path # Relative to HPO output dir
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+        storage_path = f"sqlite:///{db_path.resolve()}"
+        logger.info(f"Using Optuna storage: {storage_path}")
+    elif storage_path:
+        logger.info(f"Using Optuna storage: {storage_path} (non-SQLite)")
+    else:
+        logger.warning("No Optuna storage DB specified in config, using in-memory storage.")

     try:
-        # Create or load the study
+        # Change to single objective 'minimize'
         study = optuna.create_study(
-            study_name=args.study_name,
+            study_name=hpo_config.study_name,
             storage=storage_path,
-            direction=args.direction,
-            load_if_exists=True, # Load previous results if study exists
-            pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() # Example pruner
+            direction="minimize", # Changed to single direction
+            load_if_exists=True,
+            pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
         )
+        # Remove multi-objective check/attribute setting
+        # if not study._is_multi_objective:
+        #     logger.warning(f"Study '{hpo_config.study_name}' exists but is not multi-objective.")
+        # elif 'objective_names' not in study.user_attrs:
+        #     study.set_user_attr('objective_names', objective_names)

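Because the study is created with SQLite storage and load_if_exists=True, the same study can be reopened later to inspect results or continue the search. A minimal sketch of that usage follows; the study name and database path are placeholders for whatever the configuration actually provides.

import optuna

study = optuna.load_study(
    study_name="TimeSeriesForecasting_hpo",          # placeholder name
    storage="sqlite:///output/hpo_results/hpo.db",   # placeholder path
)
print(len(study.trials), "trials recorded so far")
df_trials = study.trials_dataframe()                 # pandas DataFrame of all trials
print(df_trials[["number", "value", "state"]].tail())
# study.optimize(objective, n_trials=10)             # would continue the same search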
         # --- Run Optimization ---
-        logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
+        logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}") # Updated log message
         study.optimize(
-            lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
-            n_trials=args.n_trials,
-            timeout=None # Optional: Set timeout in seconds
-            # Optional: Add callbacks (e.g., logging callback)
+            lambda trial: objective(trial, base_config, df), # Pass base_config
+            n_trials=hpo_config.n_trials,
+            timeout=None,
+            gc_after_trial=True
         )

-        # --- Report Best Trial ---
+        # --- Report and Process Best Trial ---
         logger.info("--- Optuna HPO Finished ---")
         logger.info(f"Number of finished trials: {len(study.trials)}")

+        # Get the single best trial
         best_trial = study.best_trial
-        logger.info(f"Best trial number: {best_trial.number}")
-        logger.info(f" Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
-        logger.info(" Best hyperparameters:")
-        for key, value in best_trial.params.items():
-            logger.info(f"  {key}: {value}")
-
-        # --- Save Best Hyperparameters (Optional) ---
-        best_params_file = output_dir / f"{args.study_name}_best_params.json"
-        try:
-            with open(best_params_file, 'w') as f:
-                import json
-                json.dump(best_trial.params, f, indent=4)
-            logger.info(f"Best hyperparameters saved to {best_params_file}")
-        except Exception as e:
-            logger.error(f"Failed to save best parameters: {e}")
+        if best_trial is None:
+            logger.warning("Optuna study finished, but no successful trial was completed.")
+        else:
+            logger.info(f"Best trial found (Trial {best_trial.number}):")
+            # Log details for the best trial
+            validation_metric_monitor = base_config.optuna.metric_to_optimize
+            logger.info(f" Objective ({validation_metric_monitor}): {best_trial.value:.5f}") # Use .value for single objective
+            logger.info(f" Hyperparameters:")
+            for key, value in best_trial.params.items():
+                logger.info(f"  {key}: {value}")
+
+            # --- Re-run and Save Artifacts for the Best Trial ---
+            logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
+            trial_output_dir = hpo_base_output_dir / f"best_trial_num{best_trial.number}" # Simplified directory name
+            trial_output_dir.mkdir(parents=True, exist_ok=True)

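Note that in Optuna, study.best_trial raises ValueError when no trial has completed rather than returning None, so an explicit try/except is one way to make the "no successful trial" branch above actually reachable. A hedged sketch of that alternative guard (not the project's own code):

import optuna

def get_best_trial_or_none(study: optuna.Study):
    # Returns None instead of raising when the study has no completed trials
    try:
        return study.best_trial
    except ValueError:
        return None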
+            try:
+                # 1. Create config for this trial
+                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
+                # Update with trial's hyperparameters
+                for key, value in best_trial.params.items(): # Use best_trial.params
+                    # Need to place the param in the correct nested config dict
+                    if key in best_config_dict['training']: best_config_dict['training'][key] = value
+                    elif key in best_config_dict['model']: best_config_dict['model'][key] = value
+                    elif key in best_config_dict['features']: best_config_dict['features'][key] = value
+                    # Add more sections if HPO tunes them
+
+                # Add non-tuned features forecast_horizon back
+                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
+
+                # Ensure evaluation plots and model saving are enabled for this final run
+                best_config_dict['evaluation']['save_plots'] = True
+                best_config_dict['training']['save_model'] = True # Assuming you want to save the best model
+
+                # Validate the final config for this trial
+                best_trial_config = MainConfig(**best_config_dict)
+
+                # Save the specific config used for this run
+                with open(trial_output_dir / "best_config.yaml", 'w') as f:
+                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False)
+
+                # 2. Run classic training (which saves model & plots)
+                logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
+                # Pass the specific config and output directory
+                run_classic_training(
+                    config=best_trial_config,
+                    full_df=df,
+                    output_base_dir=trial_output_dir # outputs -> hpo_results/best_trial_num<n>/classic_run/
+                )
+                logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {trial_output_dir}")
+
+            except Exception as e:
+                logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)
+
+            # --- Save Best Hyperparameters ---
+            best_params_file = hpo_base_output_dir / f"{hpo_config.study_name}_best_params.json" # Simplified filename
+            try:
+                with open(best_params_file, 'w') as f:
+                    import json
+                    json.dump(best_trial.params, f, indent=4) # Use best_trial.params
+                logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
+            except Exception as e:
+                logger.error(f"Failed to save parameters for best trial: {e}")
+
     except Exception as e:
-        logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
+        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
         sys.exit(1)


 if __name__ == "__main__":
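The best-trial re-run above places each tuned hyperparameter back into its nested config section by checking which section already contains that key. A generic, hedged sketch of the same idea is shown below; it is illustrative only, and the section names simply mirror the layout assumed in the diff ('training', 'model', 'features').

from typing import Any, Dict

def merge_trial_params(config_dict: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    # Copy each tuned value into the first nested section that defines the key
    for key, value in params.items():
        for section in ("training", "model", "features"):
            if key in config_dict.get(section, {}):
                config_dict[section][key] = value
                break
        else:
            # Unknown key: surface it rather than silently dropping it
            raise KeyError(f"Hyperparameter '{key}' not found in any config section")
    return config_dict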