intermediate backup

2025-05-02 14:36:19 +02:00
parent 980696aef5
commit 2b0a5728d4
16 changed files with 2780 additions and 316 deletions

View File

@@ -14,7 +14,6 @@ mpl_logger.setLevel(logging.WARNING) # Example: set to WARNING or ERROR
# --- Basic Logging Setup ---
# Configure logging early to catch basic issues.
# The level might be adjusted after config loading.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')

View File

@@ -1,22 +1,88 @@
# Configuration for the forecasting model EDA
# This file defines the settings for data loading, analysis, and visualization
# Configuration for Time Series Forecasting Pipeline
# -- General Settings --
log_level: INFO # Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
debug: true
project_name: "TimeSeriesForecasting" # Name for the project/run
random_seed: 42 # Optional: Global random seed for reproducibility
# -- IO Settings --
data_file: data/Day-ahead_Prices_60min.csv # Path to the input data CSV relative to project root
output_dir: output/reports # Directory to save generated plots and report artifacts
latex_template_file: null # Path to the LaTeX template file relative to project root
# --- Data Loading Configuration ---
data:
data_path: "data/Day-ahead_Prices_60min.csv" # Path to your CSV
# --- Raw Data Specifics ---
raw_datetime_col: "MTU (CET/CEST)" # EXACT name in your raw CSV
raw_target_col: "Day-ahead Price [EUR/MWh]" # EXACT name in your raw CSV
raw_datetime_format: '%d.%m.%Y %H:%M' # Format of the start time in the raw datetime column; load_raw_data falls back to this default if unset
# --- Standardized Names & Processing ---
datetime_col: "Timestamp" # Desired name for the index after processing
target_col: "Price" # Desired name for the target column after processing
expected_frequency: "h" # Expected frequency ('h', 'D', '15min', etc. or null)
fill_initial_target_nans: true # Fill target NaNs immediately after loading?
# -- Zoom Settings (Plotting and Analysis) --
# Optional: Specify a date range for zoomed-in plots (YYYY-MM-DD format)
# Example: zoom_start_date: "2023-01-01"
# Example: zoom_end_date: "2023-12-31"
zoom_start_date: null # Default to null
zoom_end_date: null # Default to null
# --- Feature Engineering & Preprocessing Configuration ---
features:
sequence_length: 72 # REQUIRED: Lookback window size (e.g., 72 hours = 3 days)
forecast_horizon: 24 # REQUIRED: Number of steps ahead to predict (e.g., 24 hours)
lags: [24, 48, 72, 168] # List of lag features to create (e.g., 1 day, 2 days, 3 days, 1 week)
rolling_window_sizes: [24, 72, 168] # List of window sizes for rolling stats (mean, std)
use_time_features: true # Create calendar features (hour, dayofweek, month, etc.)?
sinus_curve: true # Create a sine feature encoding the time of day?
cosin_curve: true # Create a cosine feature encoding the time of day?
fill_nan: 'ffill' # Method to fill NaNs created by lags/rolling windows ('ffill', 'bfill', 0, etc.)
scaling_method: 'standard' # Scaling method ('standard', 'minmax', or null/None for no scaling) Fit per fold.
# -- Data Settings --
expected_data_frequency: "h" # Expected frequency of the time series data (h=hourly, D=daily, M=monthly, Y=yearly)
# Optional: Wavelet Transform configuration
wavelet_transform:
apply: false # Apply wavelet transform?
target_or_feature: "target" # Apply to 'target' before other features, or 'feature' after?
wavelet_type: "db4" # Type of wavelet (e.g., 'db4', 'sym4')
level: 3 # Decomposition level (must be > 0)
use_coeffs: ["approx", "detail_1"] # Which coefficients to use as features
# Optional: Feature Clipping configuration
clipping:
apply: false # Apply clipping to generated features (excluding target)?
clip_min: 0 # Minimum value for clipping
clip_max: 400 # Maximum value for clipping
# --- Model Architecture Configuration ---
model:
# input_size: null # Removed: Calculated automatically based on features and passed directly to model
hidden_size: 128 # REQUIRED: Number of units in LSTM hidden layers
num_layers: 2 # REQUIRED: Number of LSTM layers
dropout: 0.2 # REQUIRED: Dropout rate (between 0.0 and 1.0)
use_residual_skips: false # Add residual connection from input to LSTM output?
# forecast_horizon: null # Set automatically from features.forecast_horizon
# --- Training Configuration (PyTorch Lightning) ---
training:
batch_size: 64 # REQUIRED: Batch size for training
epochs: 50 # REQUIRED: Max number of training epochs per fold
learning_rate: 0.001 # REQUIRED: Initial learning rate for Adam optimizer
loss_function: "MSE" # Loss function ('MSE' or 'MAE')
early_stopping_patience: 10 # Optional: Patience for early stopping (epochs). Set null/None to disable. Must be >= 1 if set.
scheduler_step_size: null # Optional: Step size for StepLR scheduler (epochs). Set null/None to disable. Must be > 0 if set.
scheduler_gamma: null # Optional: Gamma factor for StepLR scheduler. Set null/None to disable. Must be 0 < gamma < 1 if set.
gradient_clip_val: 1.0 # Optional: Value for gradient clipping. Set null/None to disable. Must be >= 0.0 if set.
num_workers: 0 # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
precision: 32 # Training precision (16, 32, 64, 'bf16')
# --- Cross-Validation Configuration (Rolling Window) ---
cross_validation:
n_splits: 5 # REQUIRED: Number of CV folds (must be > 0)
test_size_fraction: 0.1 # REQUIRED: Fraction of the *fixed training window size* for the test set (0 < frac < 1)
val_size_fraction: 0.1 # REQUIRED: Fraction of the *fixed training window size* for the validation set (0 < frac < 1)
initial_train_size: null # Optional: Size of the fixed training window (integer samples or float fraction of total data > 0). If null, estimated automatically.
# --- Evaluation Configuration ---
evaluation:
eval_batch_size: 128 # REQUIRED: Batch size for evaluation/testing (must be > 0)
save_plots: true # Save evaluation plots (predictions, residuals)?
plot_sample_size: 1000 # Optional: Max number of points in time series plots (must be > 0 if set)
# --- Optuna Hyperparameter Optimization Configuration ---
optuna:
enabled: false # Enable Optuna HPO? If true, requires optuna.py script.
n_trials: 20 # Number of trials to run (must be > 0)
storage: null # Optional: Optuna storage URL (e.g., "sqlite:///output/hpo_results/study.db"). If null, uses in-memory.
direction: "minimize" # Optimization direction ('minimize' or 'maximize')
metric_to_optimize: "val_mae_orig_scale" # Metric logged by LightningModule to optimize
pruning: true # Enable Optuna trial pruning?
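
As an illustrative aside (not part of this commit), a minimal sketch of how a config like this might be loaded, assuming MainConfig (re-exported by the package) is a pydantic-style model whose attribute names mirror the sections above:

import yaml
from forecasting_model.utils import MainConfig  # assumed pydantic-style model

with open("config.yaml", "r") as f:
    raw_cfg = yaml.safe_load(f)

config = MainConfig(**raw_cfg)          # validation assumed to happen here
print(config.features.sequence_length)  # -> 72
print(config.training.batch_size)       # -> 64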

View File

@@ -1,75 +0,0 @@
import argparse
import logging
import sys
from pathlib import Path
import time
# Import necessary components from your project structure
from data_analysis.utils.config_model import load_settings, Settings # Import loading function and model
from data_analysis.analysis.pipeline import run_eda_pipeline # Import the pipeline entry point
# Silence overly verbose libraries if needed (e.g., matplotlib)
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING) # Example: set to WARNING or ERROR
# --- Basic Logging Setup ---
# Configure logging early to catch basic issues.
# The level might be adjusted after config loading.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments."""
parser = argparse.ArgumentParser(
description="Run the Energy Forecasting EDA pipeline.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml', # Provide a default config file name
help="Path to the YAML configuration file."
)
# Add other potential command-line overrides here if needed later
# parser.add_argument('--debug', action='store_true', help="Override log level to DEBUG.")
args = parser.parse_args()
return args
# --- Main Execution ---
def main():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
start_time = time.perf_counter()
# --- Configuration Loading ---
_ = load_settings(config_path)
logger.info(f"Using configuration from: {config_path.resolve()} (or defaults if loading failed)")
# --- Pipeline Execution ---
try:
# Call the main function from your pipeline module
run_eda_pipeline()
end_time = time.perf_counter()
logger.info(f"Main script finished successfully in {end_time - start_time:.2f} seconds.")
except SystemExit as e:
# Catch SystemExit if pipeline runner exits intentionally
logger.warning(f"Pipeline exited with code {e.code}.")
sys.exit(e.code) # Propagate exit code
except Exception as e:
logger.critical(f"A critical error occurred during pipeline execution: {e}", exc_info=True)
end_time = time.perf_counter()
logger.info(f"Main script failed after {end_time - start_time:.2f} seconds.")
sys.exit(1)
return
if __name__ == "__main__":
main()

View File

@@ -5,4 +5,39 @@ This module provides a configurable PyTorch-based LSTM model for time series forecasting
with support for feature engineering, cross-validation, and evaluation.
"""
__version__ = "0.1.0"
__version__ = "0.1.0"
# Expose core components for easier import
from .data_processing import (
load_raw_data,
engineer_features,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders,
TimeSeriesDataset
)
from .model import LSTMForecastLightningModule
from .evaluation import (
evaluate_fold_predictions,
# Optionally expose the standalone evaluation utility if needed externally
# evaluate_model_on_fold_test_set
)
# Expose main configuration class from utils
from .utils import MainConfig
# Expose the main execution script function if it's intended to be callable as a function
# from .forecasting_model import run # Assuming the main script is named forecasting_model.py
# Define __all__ for explicit public API (optional but good practice)
__all__ = [
"load_raw_data",
"engineer_features",
"TimeSeriesCrossValidationSplitter",
"prepare_fold_data_and_loaders",
"TimeSeriesDataset",
"LSTMForecastLightningModule",
"evaluate_fold_predictions",
# "evaluate_model_on_fold_test_set", # Uncomment if exposed
"MainConfig",
# "run", # Uncomment if exposed
]
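For orientation, the re-exports above are intended to support flat imports like the following (illustrative only):

from forecasting_model import (
    MainConfig,
    load_raw_data,
    engineer_features,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
    LSTMForecastLightningModule,
    evaluate_fold_predictions,
)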

View File

@@ -1,67 +1,751 @@
import logging
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import Tuple, Generator, List, Optional
from utils.config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig
from typing import Tuple, Generator, List, Optional, Union, Dict, Literal, Type
# Use relative import for utils within the package
from .utils.config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
# Optional: Import wavelet library if needed later
# import pywt
logger = logging.getLogger(__name__)
# --- Data Loading ---
def load_raw_data(config: DataConfig) -> pd.DataFrame:
"""
Load and preprocess raw data from CSV.
Load raw time series data from a CSV file, handling specific formats,
performing initial cleaning, frequency checks, and NaN filling based on config.
Args:
config: DataConfig object containing file path, raw/standard column names,
frequency settings, and NaN handling flags.
Returns:
DataFrame with a standardized datetime index (named config.datetime_col)
and a standardized, cleaned target column (named config.target_col).
Raises:
FileNotFoundError: If the data path does not exist.
ValueError: If specified raw columns are not found, datetime parsing fails,
or frequency checks indicate critical issues.
Exception: For other pandas read_csv or processing errors.
"""
# TODO: Implement CSV loading and datetime parsing
pass
logger.info(f"Loading raw data from: {config.data_path}")
try:
# --- Initial Load ---
df = pd.read_csv(config.data_path, header=0)
logger.debug(f"Loaded raw data shape: {df.shape}")
# --- Validate Raw Columns ---
if config.raw_datetime_col not in df.columns:
raise ValueError(f"Raw datetime column '{config.raw_datetime_col}' not found in {config.data_path}")
if config.raw_target_col not in df.columns:
raise ValueError(f"Raw target column '{config.raw_target_col}' not found in {config.data_path}")
# --- Time Parsing (Specific Format Handling) ---
logger.info(f"Parsing raw datetime column: '{config.raw_datetime_col}'")
try:
# Extract the start time part 'dd.mm.yyyy hh:mm'
# Handle potential errors during split if format deviates
start_times = df[config.raw_datetime_col].astype(str).str.split(' - ', expand=True)[0]
# Define the specific format
datetime_format = config.raw_datetime_format or '%d.%m.%Y %H:%M'
# Parse to datetime, coercing errors to NaT
parsed_timestamps = pd.to_datetime(start_times, format=datetime_format, errors='coerce')
except Exception as e:
logger.error(f"Failed to split or parse raw datetime column '{config.raw_datetime_col}' using expected format: {e}", exc_info=True)
raise ValueError("Datetime parsing failed. Check raw_datetime_col format and data.")
# Check for parsing errors (NaT values)
num_parsing_errors = parsed_timestamps.isnull().sum()
if num_parsing_errors > 0:
original_len = len(df)
df = df.loc[parsed_timestamps.notnull()].copy() # Keep only rows with valid timestamps
parsed_timestamps = parsed_timestamps.dropna()
logger.warning(f"Dropped {num_parsing_errors} rows ({num_parsing_errors/original_len:.1%}) due to timestamp parsing errors "
f"(expected format: '{datetime_format}' on start time).")
if df.empty:
raise ValueError("No valid timestamps found after parsing. Check data format.")
# Assign parsed timestamp and set as index with standardized name
df[config.datetime_col] = parsed_timestamps
df = df.set_index(config.datetime_col)
logger.debug(f"Set '{config.datetime_col}' as index.")
# --- Target Column Processing ---
logger.info(f"Processing target column: '{config.raw_target_col}' -> '{config.target_col}'")
# Convert raw target to numeric, coercing errors
df[config.target_col] = pd.to_numeric(df[config.raw_target_col], errors='coerce')
# Handle NaNs caused by coercion
num_coercion_errors = df[config.target_col].isnull().sum()
if num_coercion_errors > 0:
logger.warning(f"Found {num_coercion_errors} non-numeric values in raw target column '{config.raw_target_col}'. Coerced to NaN.")
# Keep rows with NaN for now, handle based on config flag below
# --- Column Selection ---
# Keep only the standardized target column for the forecasting pipeline
# Discard raw columns and any others loaded initially
df = df[[config.target_col]]
logger.debug(f"Selected target column '{config.target_col}'. Shape: {df.shape}")
# --- Initial Target NaN Filling (Optional) ---
if config.fill_initial_target_nans:
missing_prices = df[config.target_col].isnull().sum()
if missing_prices > 0:
logger.info(f"Found {missing_prices} missing values in target column '{config.target_col}'. Applying ffill then bfill.")
df[config.target_col] = df[config.target_col].ffill()
df[config.target_col] = df[config.target_col].bfill() # Fill remaining NaNs at the start
final_missing = df[config.target_col].isnull().sum()
if final_missing > 0:
logger.error(f"{final_missing} missing values REMAIN in target column after ffill/bfill. Cannot proceed.")
raise ValueError("Target column contains unfillable NaN values.")
else:
logger.debug("No missing values found in target column.")
else:
logger.info("Skipping initial NaN filling for target column as per config.")
# Warning if NaNs exist and aren't being filled here
if df[config.target_col].isnull().any():
logger.warning(f"NaNs exist in target column '{config.target_col}' and initial filling is disabled.")
# --- Frequency Check & Setting ---
logger.info("Checking time index frequency...")
df = df.sort_index() # Ensure index is sorted before frequency checks
# Handle duplicate timestamps before frequency inference
duplicates = df.index.duplicated().sum()
if duplicates > 0:
logger.warning(f"Found {duplicates} duplicate timestamps. Keeping the first occurrence.")
df = df[~df.index.duplicated(keep='first')]
if config.expected_frequency:
inferred_freq = pd.infer_freq(df.index)
logger.debug(f"Inferred frequency: {inferred_freq}")
if inferred_freq == config.expected_frequency:
logger.info(f"Inferred frequency matches expected ('{config.expected_frequency}'). Setting index frequency.")
df = df.asfreq(config.expected_frequency)
# Check for NaNs introduced by asfreq (filling gaps)
missing_after_asfreq = df[config.target_col].isnull().sum()
if missing_after_asfreq > 0:
logger.warning(f"{missing_after_asfreq} NaNs appeared after setting frequency to '{config.expected_frequency}'. Applying ffill/bfill.")
# Only fill if initial filling was also enabled, otherwise just warn? Be explicit.
if config.fill_initial_target_nans:
df[config.target_col] = df[config.target_col].ffill().bfill()
if df[config.target_col].isnull().any():
logger.error("NaNs still present after attempting to fill gaps from asfreq. Check data continuity.")
raise ValueError("Unfillable NaNs after setting frequency.")
else:
logger.warning("Initial NaN filling was disabled, leaving NaNs introduced by asfreq.")
elif inferred_freq:
logger.warning(f"Inferred frequency ('{inferred_freq}') does NOT match expected ('{config.expected_frequency}'). Index frequency will not be explicitly set. This might affect time-based features or models assuming regular intervals.")
# Consider raising an error depending on strictness needed
# raise ValueError("Inferred frequency does not match expected frequency.")
else:
logger.error(f"Could not infer frequency, but expected frequency was set to '{config.expected_frequency}'. Check data for gaps or irregularities. Index frequency will not be explicitly set.")
# This is often a critical issue for time series models
raise ValueError("Could not infer frequency. Ensure data has regular intervals matching expected_frequency.")
else:
logger.info("No expected frequency specified in config. Skipping frequency check and setting.")
logger.info(f"Data loading and initial preparation complete. Final shape: {df.shape}")
return df
except FileNotFoundError:
logger.error(f"Data file not found at: {config.data_path}")
raise
except ValueError as e: # Catch ValueErrors raised internally or by pandas
logger.error(f"Data processing error: {e}", exc_info=True)
raise
except Exception as e: # Catch other unexpected errors
logger.error(f"Failed to load or process data from {config.data_path}: {e}", exc_info=True)
raise
# --- Feature Engineering ---
def engineer_features(df: pd.DataFrame, target_col: str, feature_config: FeatureConfig) -> pd.DataFrame:
"""
Create features from the target column and datetime index.
Create time-series features from the target column and datetime index.
This function operates on a specific slice of data provided during
cross-validation setup.
Args:
df: DataFrame containing the target column and a datetime index.
Should contain enough history for lookbacks (lags, rolling windows).
target_col: The name of the column to engineer features from.
feature_config: Configuration object specifying which features to create.
Returns:
DataFrame with original target and engineered features.
Raises:
ValueError: If target_col is not in df or configuration is invalid.
ImportError: If wavelets are requested but pywt is not installed.
"""
# TODO: Implement feature engineering (lags, rolling stats, time features, wavelets)
pass
if target_col not in df.columns:
raise ValueError(f"Target column '{target_col}' not found in DataFrame for feature engineering.")
logger.info("Starting feature engineering...")
features_df = df[[target_col]].copy() # Start with the target
# 1. Lags
if feature_config.lags:
logger.debug(f"Creating lag features for lags: {feature_config.lags}")
for lag in feature_config.lags:
if lag <= 0:
logger.warning(f"Ignoring non-positive lag value: {lag}")
continue
features_df[f'{target_col}_lag_{lag}'] = df[target_col].shift(lag)
# 2. Rolling Window Statistics
if feature_config.rolling_window_sizes:
logger.debug(f"Creating rolling window features for sizes: {feature_config.rolling_window_sizes}")
for window in feature_config.rolling_window_sizes:
if window <= 0:
logger.warning(f"Ignoring non-positive rolling window size: {window}")
continue
# Shift by 1 so the window does not include the current observation
# Use closed='left' to ensure window ends *before* the current point
rolling_obj = df[target_col].shift(1).rolling(window=window, min_periods=window // 2, closed='left')
features_df[f'{target_col}_rolling_mean_{window}'] = rolling_obj.mean()
features_df[f'{target_col}_rolling_std_{window}'] = rolling_obj.std()
# Add more stats if needed (e.g., min, max, median)
# 3. Time/Calendar Features
if feature_config.use_time_features:
logger.debug("Creating time/calendar features.")
idx = features_df.index # Use index from features_df
features_df['hour'] = idx.hour
features_df['dayofweek'] = idx.dayofweek
features_df['dayofmonth'] = idx.day
features_df['dayofyear'] = idx.dayofyear
# Ensure 'weekofyear' is Int64 to handle potential NAs if index isn't perfectly continuous (though unlikely here)
features_df['weekofyear'] = idx.isocalendar().week.astype(pd.Int64Dtype()) # pandas >= 1.1.0
features_df['month'] = idx.month
features_df['year'] = idx.year
features_df['quarter'] = idx.quarter
features_df['is_weekend'] = (idx.dayofweek >= 5).astype(int)
# 4. Sinusoidal Time Features (Optional, based on config)
if feature_config.sinus_curve:
logger.debug("Creating sinusoidal daily time feature.")
seconds_in_day = 24 * 60 * 60
seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
features_df['sin_day'] = np.sin(2 * np.pi * seconds_past_midnight / seconds_in_day)
if feature_config.cosin_curve: # Assuming this means cos for day
logger.debug("Creating cosinusoidal daily time feature.")
seconds_in_day = 24 * 60 * 60
seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
features_df['cos_day'] = np.cos(2 * np.pi * seconds_past_midnight / seconds_in_day)
# 5. Wavelet Transform (Optional)
if feature_config.wavelet_transform and feature_config.wavelet_transform.apply:
logger.warning("Wavelet feature engineering is specified but not implemented yet.")
# 6. Handling NaNs generated during feature engineering (for *generated* features)
feature_cols_generated = [col for col in features_df.columns if col != target_col]
if feature_cols_generated: # Only fill NaNs if features were actually generated
nan_handler = feature_config.fill_nan
if nan_handler is not None:
fill_value: Optional[Union[str, float]] = None
fill_method: Optional[str] = None
if isinstance(nan_handler, str):
if nan_handler in ['ffill', 'bfill']:
fill_method = nan_handler
logger.debug(f"Filling NaNs in generated features using method: '{fill_method}'")
elif nan_handler == 'mean':
logger.warning("NaN filling with 'mean' in generated features is applied globally here;"
" consider per-fold mean filling if lookahead is a concern.")
# Calculate mean only on the slice provided, potentially leaking info if slice includes val/test
# Better to use ffill/bfill here or handle after split
fill_value = features_df[feature_cols_generated].mean() # Calculate mean per feature column
logger.debug("Filling NaNs in generated features using column means.")
else:
logger.warning(f"Unsupported string fill_nan method '{nan_handler}' for generated features. Using 'ffill'.")
fill_method = 'ffill'
elif isinstance(nan_handler, (int, float)):
fill_value = float(nan_handler)
logger.debug(f"Filling NaNs in generated features with value: {fill_value}")
else:
logger.warning(f"Invalid fill_nan type: {type(nan_handler)}. NaNs in features may remain.")
# Apply filling only to generated feature columns
if fill_method:
features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method=fill_method)
if fill_method == 'ffill':
features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method='bfill')
elif fill_value is not None:
# fillna with Series/dict for column-wise mean, or scalar for constant value
features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(value=fill_value)
else:
logger.warning("`fill_nan` is None. NaNs generated by feature engineering may remain.")
remaining_nans = features_df[feature_cols_generated].isnull().sum().sum()
if remaining_nans > 0:
logger.warning(f"{remaining_nans} NaN values remain in generated features.")
# 7. Clipping (Optional) - Apply *after* feature generation but *before* scaling
if feature_config.clipping and feature_config.clipping.apply: # Check nested config
clip_config = feature_config.clipping
logger.debug(f"Clipping features (excluding target '{target_col}') between {clip_config.clip_min} and {clip_config.clip_max}")
feature_cols_to_clip = [col for col in features_df.columns if col != target_col]
if not feature_cols_to_clip:
logger.warning("Clipping enabled, but no feature columns found to clip (only target exists?).")
else:
features_df[feature_cols_to_clip] = features_df[feature_cols_to_clip].clip(
lower=clip_config.clip_min, upper=clip_config.clip_max
)
logger.info(f"Feature engineering completed. DataFrame shape: {features_df.shape}")
logger.debug(f"Feature columns: {features_df.columns.tolist()}")
return features_df
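For orientation, a sketch of the columns this function produces with the config shown earlier (target column "Price"); an illustrative listing, not generated output:

# With lags=[24, 48, 72, 168], rolling_window_sizes=[24, 72, 168] and all
# calendar/sinusoidal features enabled, the resulting columns look like:
# ['Price',
#  'Price_lag_24', 'Price_lag_48', 'Price_lag_72', 'Price_lag_168',
#  'Price_rolling_mean_24', 'Price_rolling_std_24', ..., 'Price_rolling_std_168',
#  'hour', 'dayofweek', 'dayofmonth', 'dayofyear', 'weekofyear',
#  'month', 'year', 'quarter', 'is_weekend', 'sin_day', 'cos_day']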
# --- Cross Validation ---
class TimeSeriesCrossValidationSplitter:
"""
Generates indices for time series cross-validation using a rolling (sliding) window.
The training window has a fixed size. For each split, the entire window
(train, validation, and test sets) slides forward by a specified step size
(typically the size of the test set). Validation and test set sizes are
calculated as fractions of the fixed training window size.
"""
def __init__(self, config: CrossValidationConfig, n_samples: int):
self.config = config
"""
Args:
config: CrossValidationConfig with split parameters.
n_samples: Total number of samples in the dataset.
"""
self.n_splits = config.n_splits
self.val_frac = config.val_size_fraction
self.test_frac = config.test_size_fraction
self.initial_train_size = config.initial_train_size # Used as the FIXED train size for rolling window
self.n_samples = n_samples
if not (0 < self.val_frac < 1):
raise ValueError(f"val_size_fraction must be between 0 and 1, got {self.val_frac}")
if not (0 < self.test_frac < 1):
raise ValueError(f"test_size_fraction must be between 0 and 1, got {self.test_frac}")
if self.n_splits <= 0:
raise ValueError(f"n_splits must be positive, got {self.n_splits}")
logger.info(f"Initializing TimeSeriesCrossValidationSplitter (Rolling Window): n_splits={self.n_splits}, "
f"val_frac={self.val_frac}, test_frac={self.test_frac}, initial_train_size (fixed)={self.initial_train_size}") # Clarified log
def _calculate_initial_train_size(self) -> int:
"""Determines the fixed training window size based on config or estimation."""
# Check if integer is provided
if isinstance(self.initial_train_size, int) and self.initial_train_size > 0:
if self.initial_train_size >= self.n_samples:
raise ValueError(f"initial_train_size ({self.initial_train_size}) must be less than total samples ({self.n_samples})")
logger.info(f"Using specified fixed training window size: {self.initial_train_size}")
return self.initial_train_size
# Check if float/fraction is provided
elif isinstance(self.initial_train_size, float) and 0 < self.initial_train_size < 1:
calculated_size = int(self.n_samples * self.initial_train_size)
if calculated_size <= 0:
raise ValueError("initial_train_size fraction results in non-positive size.")
logger.info(f"Using fixed training window size calculated from fraction: {calculated_size}")
return calculated_size
# Estimate if None
elif self.initial_train_size is None:
min_samples_per_split_step = 2 # Heuristic minimum samples for val+test in one step
# Estimate val/test based on *potential* train size (crude)
# Assume train is roughly (1 - val - test) fraction for estimation
estimated_train_frac = max(0.1, 1.0 - self.val_frac - self.test_frac) # Ensure non-zero
estimated_train_n = int(self.n_samples * estimated_train_frac)
val_test_size_per_step = max(min_samples_per_split_step, int(estimated_train_n * (self.val_frac + self.test_frac)))
# Tentative initial train size is total minus one val/test block
fixed_train_n_est = self.n_samples - val_test_size_per_step
# Basic sanity checks
if fixed_train_n_est <= 0:
raise ValueError("Could not estimate a valid initial_train_size (<= 0). Please specify it or check CV fractions.")
# Need at least 1 sample for train, val, test each theoretically
est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
if fixed_train_n_est + est_val_size + est_test_size > self.n_samples:
# If the simple estimate is too large, reduce it more drastically
# Try setting train size = 50% and see if val/test fit?
fixed_train_n_est = int(self.n_samples * 0.5)
est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
if fixed_train_n_est <=0 or (fixed_train_n_est + est_val_size + est_test_size > self.n_samples):
raise ValueError("Could not estimate a valid initial_train_size. Data too small relative to val/test fractions? Please specify initial_train_size.")
logger.warning(f"initial_train_size not set, estimated fixed train size for rolling window: {fixed_train_n_est}. "
"This is a heuristic; viability depends on n_splits and step size. Validation happens in split().")
return fixed_train_n_est
else:
raise ValueError(f"Invalid initial_train_size: {self.initial_train_size}")
def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
"""
Generate train/val/test splits using expanding window approach.
Generate train/validation/test indices for each fold using a rolling window.
Pre-calculates the number of possible splits based on data size and window parameters.
Yields:
Tuple of (train_indices, val_indices, test_indices) for each fold.
Raises:
ValueError: If parameters lead to invalid split sizes or overlaps,
or if the data is too small for the configuration.
"""
# TODO: Implement expanding window CV splitter
pass
indices = np.arange(self.n_samples)
fixed_train_n = self._calculate_initial_train_size() # This is now the fixed size
# Calculate val/test sizes based on the *fixed* training size. Min size of 1.
val_size = max(1, int(fixed_train_n * self.val_frac))
test_size = max(1, int(fixed_train_n * self.test_frac))
# Calculate the total size of one complete train+val+test window
fold_window_size = fixed_train_n + val_size + test_size
# Check if even the first window fits
if fold_window_size > self.n_samples:
raise ValueError(f"Configuration Error: The total window size (Train {fixed_train_n} + Val {val_size} + Test {test_size} = {fold_window_size}) "
f"exceeds total samples ({self.n_samples}). Decrease initial_train_size, fractions, or increase data.")
# Determine the step size (how much the window slides)
# Default: slide by the test set size for contiguous, non-overlapping test periods
step_size = test_size
if step_size <= 0:
raise ValueError(f"Step size (derived from test_size {test_size}) must be positive.")
# --- Calculate the number of splits actually possible ---
# Last possible start index for the train set
last_possible_train_start_idx = self.n_samples - fold_window_size
# Calculate how many steps fit within this range (integer division)
# If last possible start is 5, step is 2: steps possible at 0, 2, 4 => (5 // 2) + 1 = 2 + 1 = 3
num_possible_steps = max(0, last_possible_train_start_idx // step_size) + 1 # +1 because we start at index 0
# Use the minimum of requested splits and possible splits
actual_n_splits = min(self.n_splits, num_possible_steps)
if actual_n_splits < self.n_splits:
logger.warning(f"Data size ({self.n_samples} samples) only allows for {actual_n_splits} splits "
f"with fixed train size {fixed_train_n}, val size {val_size}, test size {test_size} (total window {fold_window_size}) and step size {step_size} "
f"(requested {self.n_splits}).")
elif actual_n_splits == 0:
# This case should be caught by the fold_window_size > self.n_samples check, but belt-and-suspenders
logger.error("Data too small for even one split with the rolling window configuration.")
return # Return generator that yields nothing
# --- Generate the splits ---
for i in range(actual_n_splits):
logger.debug(f"Generating indices for fold {i+1}/{actual_n_splits} (Rolling Window)") # Log using actual_n_splits
# Calculate window boundaries for this fold
train_start_idx = i * step_size
train_end_idx = train_start_idx + fixed_train_n
val_start_idx = train_end_idx
val_end_idx = val_start_idx + val_size
test_start_idx = val_end_idx
test_end_idx = test_start_idx + test_size # = train_start_idx + fold_window_size
# Determine indices for this fold using slicing
train_indices = indices[train_start_idx:train_end_idx]
val_indices = indices[val_start_idx:val_end_idx]
test_indices = indices[test_start_idx:test_end_idx]
# --- Basic Validation Checks (Optional, should be guaranteed by calculations) ---
# Ensure no overlap (guaranteed by slicing if sizes > 0)
# Ensure sequence (guaranteed by slicing)
logger.info(f"Fold {i+1}: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")
yield train_indices, val_indices, test_indices
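To make the window arithmetic above concrete, a worked example with hypothetical numbers (not taken from the repository):

# n_samples=1000, fixed train size=600, val_frac=test_frac=0.1
#   val_size = test_size = max(1, int(600 * 0.1)) = 60
#   fold_window_size = 600 + 60 + 60 = 720
#   step_size = test_size = 60
#   last_possible_train_start_idx = 1000 - 720 = 280
#   num_possible_steps = 280 // 60 + 1 = 5   # at most 5 folds, even if n_splits is larger
#
# Typical usage (cv_config assumed to be a CrossValidationConfig):
# splitter = TimeSeriesCrossValidationSplitter(cv_config, n_samples=len(df))
# for train_idx, val_idx, test_idx in splitter.split():
#     ...prepare loaders and train one fold...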
# --- Dataset Class ---
class TimeSeriesDataset(Dataset):
def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: int):
self.data = data_array
"""
PyTorch Dataset for time series forecasting.
Takes a NumPy array (features + target), sequence length, and forecast horizon,
and returns (input_sequence, target_sequence) tuples. Compatible with PyTorch
DataLoaders used by PyTorch Lightning.
"""
def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: int, target_col_index: int = 0):
"""
Args:
data_array: Numpy array of shape (n_samples, n_features).
Assumes the target variable is one of the columns.
sequence_length: Length of the input sequence (lookback window).
forecast_horizon: Number of steps ahead to predict.
target_col_index: Index of the target column in data_array. Defaults to 0.
"""
if sequence_length <= 0:
raise ValueError("sequence_length must be positive.")
if forecast_horizon <= 0:
raise ValueError("forecast_horizon must be positive.")
if data_array.ndim != 2:
raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")
min_len_required = sequence_length + forecast_horizon
if min_len_required > data_array.shape[0]:
raise ValueError(f"sequence_length ({sequence_length}) + forecast_horizon ({forecast_horizon}) = {min_len_required} "
f"exceeds total samples provided ({data_array.shape[0]})")
if not (0 <= target_col_index < data_array.shape[1]):
raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")
self.data = torch.tensor(data_array, dtype=torch.float32)
self.sequence_length = sequence_length
self.forecast_horizon = forecast_horizon
self.target_col_index = target_col_index
self.n_samples = data_array.shape[0]
self.n_features = data_array.shape[1]
logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
f"seq_len={self.sequence_length}, forecast_horizon={self.forecast_horizon}, "
f"target_idx={self.target_col_index}")
def __len__(self) -> int:
# TODO: Implement length calculation
pass
"""Returns the total number of sequences that can be generated."""
return self.n_samples - self.sequence_length - self.forecast_horizon + 1
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: Implement sequence extraction
pass
"""
Returns a single (input_sequence, target_sequence) pair.
"""
if not (0 <= idx < len(self)):
raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")
input_start = idx
input_end = idx + self.sequence_length
input_sequence = self.data[input_start:input_end, :]
target_start = input_end
target_end = target_start + self.forecast_horizon
target_sequence = self.data[target_start:target_end, self.target_col_index]
return input_sequence, target_sequence
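For reference, a sketch of the shapes this dataset yields, assuming 10 feature columns, sequence_length=72 and forecast_horizon=24 (illustrative numbers):

# data_array.shape == (N, 10)
# len(dataset) == N - 72 - 24 + 1
# x, y = dataset[0]
# x.shape == torch.Size([72, 10])   # full feature window
# y.shape == torch.Size([24])       # target column only, next 24 steps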
# --- Data Preparation ---
def prepare_fold_data_and_loaders(
full_df: pd.DataFrame,
full_df: pd.DataFrame, # Should contain only the target initially
train_idx: np.ndarray,
val_idx: np.ndarray,
test_idx: np.ndarray,
target_col: str,
feature_config: FeatureConfig,
train_config: TrainingConfig,
eval_config: EvaluationConfig
) -> Tuple[DataLoader, DataLoader, DataLoader, object, int]:
) -> Tuple[DataLoader, DataLoader, DataLoader, Union[StandardScaler, MinMaxScaler, None], int]:
"""
Prepare data loaders for a single fold.
Prepares data loaders for a single cross-validation fold.
This is essential for time-series CV where scaling must be fitted *only*
on the training data of the current fold to prevent lookahead bias.
The resulting DataLoaders can be used directly with a PyTorch Lightning Trainer
within the cross-validation loop in `main.py`.
Steps:
1. Determines the full data range needed for the fold (incl. history for features).
2. Engineers features on this slice using `engineer_features`.
3. Splits the engineered data into train, validation, test sets based on indices.
4. **Fits a scaler ONLY on the training data for the current fold.**
5. Transforms train, validation, and test sets using the fitted scaler.
6. Creates `TimeSeriesDataset` instances for each set.
7. Creates `DataLoader` instances for each set.
Args:
full_df: The complete raw DataFrame (datetime index, target column).
train_idx: Array of integer indices for the training set.
val_idx: Array of integer indices for the validation set.
test_idx: Array of integer indices for the test set.
target_col: Name of the target column.
feature_config: Configuration for feature engineering.
train_config: Configuration for training (used for batch size, device hints).
eval_config: Configuration for evaluation (used for batch size).
Returns:
Tuple containing:
- train_loader: DataLoader for the training set.
- val_loader: DataLoader for the validation set.
- test_loader: DataLoader for the test set.
- target_scaler: The scaler fitted on the target variable (for inverse transform). Can be None.
- input_size: The number of features in the input sequences (X).
Raises:
ValueError: If indices are invalid, data splitting fails, or NaNs persist.
ImportError: If feature engineering requires an uninstalled library.
"""
# TODO: Implement data preparation pipeline
pass
logger.info(f"Preparing data loaders for fold: train_size={len(train_idx)}, val_size={len(val_idx)}, test_size={len(test_idx)}")
if len(train_idx) == 0 or len(val_idx) == 0 or len(test_idx) == 0:
raise ValueError("Received empty indices for train, validation, or test set.")
# 1. Determine data slice needed including history for feature lookback
max_lookback = 0
if feature_config.lags:
max_lookback = max(max_lookback, max(feature_config.lags))
if feature_config.rolling_window_sizes:
max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) - 1)
max_history_needed = max(max_lookback, feature_config.sequence_length)
slice_start_idx = max(0, train_idx[0] - max_history_needed)
slice_end_idx = test_idx[-1] + 1
if slice_start_idx >= slice_end_idx:
raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices.")
fold_data_slice = full_df.iloc[slice_start_idx:slice_end_idx]
logger.debug(f"Required data slice for fold: indices {slice_start_idx} to {slice_end_idx-1} "
f"(size {len(fold_data_slice)}) for history and fold data.")
if fold_data_slice.empty:
raise ValueError(f"Data slice for fold is empty (indices {slice_start_idx} to {slice_end_idx-1}).")
# 2. Feature Engineering on the slice
try:
engineered_df = engineer_features(fold_data_slice.copy(), target_col, feature_config)
if engineered_df.empty:
raise ValueError("Feature engineering resulted in an empty DataFrame.")
except Exception as e:
logger.error(f"Feature engineering failed for fold: {e}")
raise
# 3. Map absolute indices to iloc positions in the potentially modified engineered_df
try:
# Use index intersection to find valid locations
train_indices_dt = full_df.index[train_idx]
val_indices_dt = full_df.index[val_idx]
test_indices_dt = full_df.index[test_idx]
adj_train_idx_loc = engineered_df.index.get_indexer(train_indices_dt.intersection(engineered_df.index))
adj_val_idx_loc = engineered_df.index.get_indexer(val_indices_dt.intersection(engineered_df.index))
adj_test_idx_loc = engineered_df.index.get_indexer(test_indices_dt.intersection(engineered_df.index))
# Filter out any -1s just in case (shouldn't happen with intersection)
adj_train_idx_loc = adj_train_idx_loc[adj_train_idx_loc != -1]
adj_val_idx_loc = adj_val_idx_loc[adj_val_idx_loc != -1]
adj_test_idx_loc = adj_test_idx_loc[adj_test_idx_loc != -1]
if len(adj_train_idx_loc) == 0 or len(adj_val_idx_loc) == 0 or len(adj_test_idx_loc) == 0:
logger.error(f"Index mapping resulted in empty splits: Train({len(adj_train_idx_loc)}), Val({len(adj_val_idx_loc)}), Test({len(adj_test_idx_loc)})")
logger.debug(f"Original counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
logger.debug(f"Engineered DF index span: {engineered_df.index.min()} to {engineered_df.index.max()}")
raise ValueError("Mapping original indices to engineered DataFrame resulted in empty splits. Check CV indices and NaN handling.")
except Exception as e:
logger.error(f"Error mapping indices to engineered DataFrame: {e}")
raise ValueError("Failed to map CV indices to the feature-engineered data slice.")
# 4. Split engineered data using iloc positions
train_df = engineered_df.iloc[adj_train_idx_loc]
val_df = engineered_df.iloc[adj_val_idx_loc]
test_df = engineered_df.iloc[adj_test_idx_loc]
logger.debug(f"Fold split shapes after feature engineering: Train={train_df.shape}, Val={val_df.shape}, Test={test_df.shape}")
if train_df.empty or val_df.empty or test_df.empty:
raise ValueError("One or more data splits (train, val, test) are empty after feature engineering and splitting.")
# --- Final Check for NaNs before scaling/Dataset creation ---
if train_df.isnull().any().any():
nan_cols = train_df.columns[train_df.isnull().any()].tolist()
logger.error(f"NaNs found in FINAL training data before scaling. Columns: {nan_cols}")
logger.debug(f"NaN counts per column in train_df:\n{train_df.isnull().sum()[train_df.isnull().any()]}")
raise ValueError("NaNs present in training data before scaling. Check feature engineering NaN handling.")
if val_df.isnull().any().any() or test_df.isnull().any().any():
logger.warning("NaNs found in final validation or test data splits. This might cause issues during evaluation or testing.")
# 5. Scaling (Fit on Train, Transform All) - CRITICAL PER-FOLD STEP
feature_cols = train_df.columns.tolist()
try:
target_col_index_in_features = feature_cols.index(target_col)
except ValueError:
raise ValueError(f"Target column '{target_col}' not found in the final feature columns: {feature_cols}")
scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
ScalerClass: Optional[Type[Union[StandardScaler, MinMaxScaler]]] = None
if feature_config.scaling_method == 'standard':
ScalerClass = StandardScaler
elif feature_config.scaling_method == 'minmax':
ScalerClass = MinMaxScaler
elif feature_config.scaling_method is None:
logger.info("No scaling applied for this fold.")
else:
raise ValueError(f"Unsupported scaling method: {feature_config.scaling_method}")
train_data = train_df[feature_cols].values
val_data = val_df[feature_cols].values
test_data = test_df[feature_cols].values
if ScalerClass is not None:
scaler = ScalerClass()
target_scaler = ScalerClass()
logger.info(f"Applying {feature_config.scaling_method} scaling. Fitting on training data for the fold.")
scaler.fit(train_data)
target_scaler.fit(train_data[:, target_col_index_in_features].reshape(-1, 1))
train_data_scaled = scaler.transform(train_data)
val_data_scaled = scaler.transform(val_data)
test_data_scaled = scaler.transform(test_data)
logger.debug("Scaling complete for the fold.")
else:
train_data_scaled = train_data
val_data_scaled = val_data
test_data_scaled = test_data
input_size = train_data_scaled.shape[1]
# 6. Dataset Instantiation
logger.debug("Creating TimeSeriesDataset instances for the fold.")
try:
train_dataset = TimeSeriesDataset(
train_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
val_dataset = TimeSeriesDataset(
val_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
test_dataset = TimeSeriesDataset(
test_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon, target_col_index=target_col_index_in_features
)
except ValueError as e:
logger.error(f"Error creating TimeSeriesDataset: {e}")
logger.error(f"Shapes fed to Dataset: Train={train_data_scaled.shape}, Val={val_data_scaled.shape}, Test={test_data_scaled.shape}")
logger.error(f"SeqLen={feature_config.sequence_length}, Horizon={feature_config.forecast_horizon}")
raise
# 7. DataLoader Creation
logger.debug("Creating DataLoaders for the fold.")
num_workers = getattr(train_config, 'num_workers', 0)
pin_memory = torch.cuda.is_available() # Pin memory if CUDA is available
train_loader = DataLoader(
train_dataset, batch_size=train_config.batch_size, shuffle=True,
num_workers=num_workers, pin_memory=pin_memory, drop_last=False
)
val_loader = DataLoader(
val_dataset, batch_size=eval_config.eval_batch_size, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory, drop_last=False
)
test_loader = DataLoader(
test_dataset, batch_size=eval_config.eval_batch_size, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory, drop_last=False
)
logger.info("Data loaders prepared successfully for the fold.")
return train_loader, val_loader, test_loader, target_scaler, input_size
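As an illustrative aside (not part of this commit), a sketch of the per-fold loop these helpers are designed for; cfg is an assumed MainConfig instance whose attribute names mirror the YAML sections, and the model/trainer lines are placeholders:

# df = load_raw_data(cfg.data)
# splitter = TimeSeriesCrossValidationSplitter(cfg.cross_validation, n_samples=len(df))
# for fold, (train_idx, val_idx, test_idx) in enumerate(splitter.split()):
#     train_loader, val_loader, test_loader, target_scaler, input_size = \
#         prepare_fold_data_and_loaders(
#             full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
#             target_col=cfg.data.target_col, feature_config=cfg.features,
#             train_config=cfg.training, eval_config=cfg.evaluation,
#         )
#     # model = LSTMForecastLightningModule(...)   # constructor not shown in this diff
#     # trainer.fit(model, train_loader, val_loader)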

View File

@@ -1,82 +1,325 @@
import logging
import os
from pathlib import Path # Added
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from typing import Dict, Any, Optional
from utils.config_model import EvaluationConfig
from sklearn.preprocessing import StandardScaler, MinMaxScaler # For type hinting target_scaler
from typing import Dict, Any, Optional, Union, List, Tuple
# import matplotlib.pyplot as plt # No longer needed directly
# import seaborn as sns # No longer needed directly
def calculate_mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
Calculate Mean Absolute Error.
"""
# TODO: Implement MAE calculation
pass
# Assuming config_model and io.plotting are accessible
from forecasting_model.utils.config_model import EvaluationConfig
from forecasting_model.io.plotting import ( # Import the plotting utilities
setup_plot_style,
save_plot,
create_time_series_plot,
create_scatter_plot,
create_residuals_plot,
create_residuals_distribution_plot
)
def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
Calculate Root Mean Squared Error.
"""
# TODO: Implement RMSE calculation
pass
def plot_predictions_vs_actual(
y_true: np.ndarray,
y_pred: np.ndarray,
title_suffix: str,
filename: str,
max_points: Optional[int] = None
) -> None:
"""
Create line plot of predictions vs actual values.
"""
# TODO: Implement prediction vs actual plot
pass
def plot_scatter_predictions(
y_true: np.ndarray,
y_pred: np.ndarray,
title_suffix: str,
filename: str
) -> None:
"""
Create scatter plot of predictions vs actual values.
"""
# TODO: Implement scatter plot
pass
def plot_residuals_time(
residuals: np.ndarray,
title_suffix: str,
filename: str,
max_points: Optional[int] = None
) -> None:
"""
Create plot of residuals over time.
"""
# TODO: Implement residuals time plot
pass
def plot_residuals_distribution(
residuals: np.ndarray,
title_suffix: str,
filename: str
) -> None:
"""
Create histogram/KDE of residuals.
"""
# TODO: Implement residuals distribution plot
pass
logger = logging.getLogger(__name__)
# --- Metric Calculations (Utilities - Optional) ---
# (Keep calculate_mae_np, calculate_rmse_np if needed as standalone utils)
# ... (code for calculate_mae_np, calculate_rmse_np unchanged) ...
def calculate_mae_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
[Optional Utility] Calculate Mean Absolute Error using NumPy.
Prefer torchmetrics inside training/validation loops.
Args:
y_true: Ground truth values (flattened).
y_pred: Predicted values (flattened).
Returns:
Calculated MAE, or NaN if inputs are invalid.
"""
if y_true.shape != y_pred.shape:
logger.error(f"Shape mismatch for MAE: y_true={y_true.shape}, y_pred={y_pred.shape}")
return np.nan
if len(y_true) == 0:
logger.warning("Attempting to calculate MAE on empty arrays.")
return np.nan
try:
# Use scikit-learn for robustness if available, otherwise basic numpy
from sklearn.metrics import mean_absolute_error
mae = mean_absolute_error(y_true, y_pred)
except ImportError:
mae = np.mean(np.abs(y_true - y_pred))
return float(mae)
def evaluate_fold(
model: torch.nn.Module,
test_loader: DataLoader,
loss_fn: torch.nn.Module,
device: torch.device,
target_scaler: Any,
def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
[Optional Utility] Calculate Root Mean Squared Error using NumPy.
Prefer torchmetrics inside training/validation loops.
Args:
y_true: Ground truth values (flattened).
y_pred: Predicted values (flattened).
Returns:
Calculated RMSE, or NaN if inputs are invalid.
"""
if y_true.shape != y_pred.shape:
logger.error(f"Shape mismatch for RMSE: y_true={y_true.shape}, y_pred={y_pred.shape}")
return np.nan
if len(y_true) == 0:
logger.warning("Attempting to calculate RMSE on empty arrays.")
return np.nan
try:
# Use scikit-learn for robustness if available, otherwise basic numpy
from sklearn.metrics import mean_squared_error
mse = mean_squared_error(y_true, y_pred, squared=True)
except ImportError:
mse = np.mean((y_true - y_pred)**2)
rmse = np.sqrt(mse)
return float(rmse)
# --- Plotting Functions (Utilities) ---
# REMOVED - These are now imported from io.plotting
# --- Fold Evaluation Function ---
def evaluate_fold_predictions(
y_true_scaled: np.ndarray,
y_pred_scaled: np.ndarray,
target_scaler: Union[StandardScaler, MinMaxScaler, None],
eval_config: EvaluationConfig,
fold_num: int
fold_num: int,
output_dir: str, # Base output directory (e.g., output/cv_results)
time_index: Optional[np.ndarray] = None # Optional: Pass time index for x-axis
) -> Dict[str, float]:
"""
Evaluate model on test set and generate plots.
Processes prediction results for a fold's test set using torchmetrics.
Takes scaled predictions and targets, inverse transforms them,
calculates final metrics (MAE, RMSE) using torchmetrics.functional,
and generates evaluation plots using utilities from io.plotting. Assumes
model inference is already done.
Args:
y_true_scaled: Numpy array of scaled ground truth targets (n_samples, horizon).
y_pred_scaled: Numpy array of scaled model predictions (n_samples, horizon).
target_scaler: The scaler fitted on the target variable during training. Needed
for inverse transforming to original scale. Can be None.
eval_config: Configuration object for evaluation parameters (e.g., plotting).
fold_num: The current fold number (e.g., 0, 1, ...).
output_dir: The base directory to save fold-specific outputs (plots, metrics).
time_index: Optional array representing the time index for the test set,
used for x-axis in time-based plots. If None, uses integer indices.
Returns:
Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
original scale. Metrics will be NaN if inverse transform or calculation fails.
Raises:
ValueError: If input shapes are inconsistent or required scaler is missing.
"""
# TODO: Implement full evaluation pipeline
pass
logger.info(f"Processing evaluation results for Fold {fold_num + 1}...")
fold_id = fold_num + 1 # Use 1-based indexing for reporting/filenames
if y_true_scaled.shape != y_pred_scaled.shape:
raise ValueError(f"Shape mismatch between targets and predictions: "
f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
if y_true_scaled.ndim != 2:
raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")
n_samples, horizon = y_true_scaled.shape
logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")
# --- Inverse Transform (Outputs NumPy) ---
y_true_flat_scaled = y_true_scaled.reshape(-1, 1)
y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)
y_true_inv_np: np.ndarray
y_pred_inv_np: np.ndarray
if target_scaler is not None:
try:
logger.debug("Inverse transforming predictions and targets.")
y_true_inv_np = target_scaler.inverse_transform(y_true_flat_scaled)
y_pred_inv_np = target_scaler.inverse_transform(y_pred_flat_scaled)
# Flatten NumPy arrays for metric calculation and plotting
y_true_np = y_true_inv_np.flatten()
y_pred_np = y_pred_inv_np.flatten()
except Exception as e:
logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
logger.error("Metrics calculation will be skipped due to inverse transform failure.")
return {'MAE': np.nan, 'RMSE': np.nan}
else:
logger.info("No target scaler provided, assuming inputs are already on original scale.")
# Flatten NumPy arrays for metric calculation and plotting
y_true_np = y_true_flat_scaled.flatten()
y_pred_np = y_pred_flat_scaled.flatten()
# --- Calculate Metrics using torchmetrics.functional ---
metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Initialize with NaN
try:
if len(y_true_np) > 0: # Check if data exists after potential failures
y_true_tensor = torch.from_numpy(y_true_np).float().cpu()
y_pred_tensor = torch.from_numpy(y_pred_np).float().cpu()
mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
rmse_tensor = torch.sqrt(mse_tensor)
metrics['MAE'] = mae_tensor.item()
metrics['RMSE'] = rmse_tensor.item()
logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
else:
logger.warning(f"Skipping metric calculation for Fold {fold_id} due to empty data after inverse transform.")
except Exception as e:
logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
# metrics already initialized to NaN
# --- Generate Plots (Optional - uses plotting utilities) ---
if eval_config.save_plots and len(y_true_np) > 0:
logger.info(f"Generating evaluation plots for Fold {fold_id}...")
# Define plot directory and setup style
fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
setup_plot_style() # Apply consistent styling
title_suffix = f"Fold {fold_id} Test Set"
residuals_np = y_true_np - y_pred_np
# Determine x-axis: use provided time_index if available, else integer indices
# Note: Flattened y_true/y_pred have length n_samples * horizon
# Need an appropriate index for this flattened view if time_index is provided.
# Simple approach: use integer indices for flattened data.
plot_indices = np.arange(len(y_true_np))
xlabel = "Time Index (Flattened Horizon x Samples)"
# If time_index corresponding to the start of each forecast is passed,
# more sophisticated x-axis handling could be done, but integer indices are simpler.
try:
# Create and save each plot using utility functions
fig_ts = create_time_series_plot(
plot_indices, y_true_np, y_pred_np,
f"Predictions vs Actual - {title_suffix}",
xlabel=xlabel,
ylabel="Value (Original Scale)",
max_points=eval_config.plot_sample_size
)
save_plot(fig_ts, fold_plot_dir / "predictions_vs_actual.png")
fig_scatter = create_scatter_plot(
y_true_np, y_pred_np,
f"Scatter Plot - {title_suffix}",
xlabel="Actual Values (Original Scale)",
ylabel="Predicted Values (Original Scale)"
)
save_plot(fig_scatter, fold_plot_dir / "scatter_predictions.png")
fig_res_time = create_residuals_plot(
plot_indices, residuals_np,
f"Residuals Over Time - {title_suffix}",
xlabel=xlabel,
ylabel="Residual (Original Scale)",
max_points=eval_config.plot_sample_size
)
save_plot(fig_res_time, fold_plot_dir / "residuals_time.png")
fig_res_dist = create_residuals_distribution_plot(
residuals_np,
f"Residuals Distribution - {title_suffix}",
xlabel="Residual Value (Original Scale)",
ylabel="Density"
)
save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution.png")
logger.info(f"Evaluation plots saved to: {fold_plot_dir}")
except Exception as e:
logger.error(f"Failed to generate or save one or more plots for Fold {fold_id}: {e}", exc_info=True)
# Continue without plots, metrics are already calculated.
elif eval_config.save_plots and len(y_true_np) == 0:
logger.warning(f"Skipping plot generation for Fold {fold_id} due to empty data.")
logger.info(f"Evaluation processing finished for Fold {fold_id}.")
return metrics
# --- (Optional) Wrapper for non-PL usage or direct testing ---
# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
model: torch.nn.Module,
test_loader: DataLoader,
device: torch.device,
target_scaler: Union[StandardScaler, MinMaxScaler, None],
eval_config: EvaluationConfig,
fold_num: int,
output_dir: str
) -> Dict[str, float]:
"""
[Optional Function] Evaluates a given model on a fold's test set.
Runs the inference loop, collects scaled results, then processes them using
`evaluate_fold_predictions` (which now uses plotting utilities).
Useful for standalone testing or if not using pl.Trainer.test().
"""
# ... (Implementation of inference loop remains the same) ...
logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
model.eval()
model.to(device)
all_preds_scaled_list: List[torch.Tensor] = []
all_targets_scaled_list: List[torch.Tensor] = []
with torch.no_grad():
for i, (X_batch, y_batch) in enumerate(test_loader):
try:
X_batch = X_batch.to(device)
outputs = model(X_batch) # Scaled outputs
# Ensure outputs match target shape (e.g., handle trailing dimension)
if outputs.shape != y_batch.shape:
if outputs.ndim == y_batch.ndim + 1 and outputs.shape[-1] == 1:
outputs = outputs.squeeze(-1)
if outputs.shape != y_batch.shape:
raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
all_preds_scaled_list.append(outputs.cpu())
all_targets_scaled_list.append(y_batch.cpu()) # Keep targets on CPU
except Exception as e:
logger.error(f"Error during inference batch {i} for Fold {fold_num+1}: {e}", exc_info=True)
raise ValueError(f"Inference failed on batch {i} for Fold {fold_num+1}")
# Concatenate results from all batches
try:
if not all_preds_scaled_list or not all_targets_scaled_list:
logger.error(f"No prediction results collected for Fold {fold_num + 1}. Check test_loader.")
return {'MAE': np.nan, 'RMSE': np.nan}
y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()
y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
except Exception as e:
logger.error(f"Error concatenating prediction results for Fold {fold_num + 1}: {e}", exc_info=True)
raise ValueError("Failed to combine batch results during evaluation inference.")
# Process the collected predictions using the refactored function
# No time_index passed here by default, plotting will use integer indices
return evaluate_fold_predictions(
y_true_scaled=y_true_scaled,
y_pred_scaled=y_pred_scaled,
target_scaler=target_scaler,
eval_config=eval_config,
fold_num=fold_num,
output_dir=output_dir,
time_index=None # Explicitly pass None
)
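A minimal usage sketch of this wrapper (illustrative only: `model`, `test_loader`, and `target_scaler` are assumed to come from an earlier fold-preparation step):

    from forecasting_model.utils.config_model import EvaluationConfig

    eval_cfg = EvaluationConfig(eval_batch_size=64, save_plots=True, plot_sample_size=1000)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    metrics = evaluate_model_on_fold_test_set(
        model=model,                  # trained torch.nn.Module (assumed available)
        test_loader=test_loader,      # DataLoader yielding (x, y) batches (assumed available)
        device=device,
        target_scaler=target_scaler,  # fitted scaler for this fold, or None
        eval_config=eval_cfg,
        fold_num=0,
        output_dir="output/cv_results",
    )
    print(metrics)  # {'MAE': ..., 'RMSE': ...}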

View File

@ -1,5 +1,26 @@
"""
IO utilities for the forecasting model.
Input/Output utilities for the forecasting model package.
Currently, primarily includes plotting functions used internally by evaluation.
"""
This package contains utilities for data loading, saving, and visualization.
"""
# Expose plotting utilities if intended for external use
# from .plotting import (
# setup_plot_style,
# save_plot,
# create_time_series_plot,
# create_scatter_plot,
# create_residuals_plot,
# create_residuals_distribution_plot
# )
# __all__ = [
# "setup_plot_style",
# "save_plot",
# "create_time_series_plot",
# "create_scatter_plot",
# "create_residuals_plot",
# "create_residuals_distribution_plot",
# ]
# If nothing is intended for public API from this submodule, leave this file empty
# or with just a docstring.

View File

@ -1,75 +1,307 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def setup_plot_style(use_seaborn: bool = True) -> None:
"""
Set up a consistent plotting style using seaborn if enabled.
Args:
use_seaborn: Whether to apply seaborn styling.
"""
if use_seaborn:
try:
sns.set_theme(style="whitegrid", palette="muted")
plt.rcParams['figure.figsize'] = (12, 6) # Default figure size
logger.debug("Seaborn plot style set.")
except Exception as e:
logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
else:
# Optional: Define a default matplotlib style if seaborn is not used
plt.style.use('default')
logger.debug("Using default matplotlib plot style.")
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
"""
Save matplotlib figure to a file with directory creation and error handling.
Args:
fig: The matplotlib Figure object to save.
filename: The full path (including filename and extension) to save the plot to.
Can be a string or Path object.
Raises:
OSError: If the directory cannot be created.
Exception: For other file saving errors.
"""
filepath = Path(filename)
try:
# Create the parent directory if it doesn't exist
filepath.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(filepath, bbox_inches='tight', dpi=150) # Save with tight bounding box and decent resolution
logger.info(f"Plot saved successfully to: {filepath}")
except OSError as e:
logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
raise # Re-raise OSError for directory creation issues
except Exception as e:
logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
raise # Re-raise other saving errors
finally:
# Close the figure to free up memory, regardless of saving success
plt.close(fig)
def create_time_series_plot(
x: np.ndarray,
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
xlabel: str = "Time Index",
ylabel: str = "Value",
max_points: Optional[int] = None
) -> plt.Figure:
"""
Create a time series plot comparing actual vs predicted values.
Args:
x: The array for the x-axis (e.g., time steps, indices).
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
if len(x) == 0:
logger.warning("Attempting to create time series plot with empty data.")
# Return an empty figure or raise error? Let's return empty.
return plt.figure()
logger.debug(f"Creating time series plot: {title}")
fig, ax = plt.subplots(figsize=(15, 6)) # Consistent size
n_points = len(x)
indices = np.arange(n_points) # Use internal indices for potential slicing
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = indices[::step]
plot_x = x[::step]
plot_y_true = y_true[::step]
plot_y_pred = y_pred[::step]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
plot_x = x
plot_y_true = y_true
plot_y_pred = y_pred
effective_title = title
ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)
ax.set_title(effective_title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
def create_scatter_plot(
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
xlabel: str = "Actual Values",
ylabel: str = "Predicted Values"
) -> plt.Figure:
"""
Create a scatter plot of actual vs predicted values.
Args:
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
if len(y_true) == 0:
logger.warning("Attempting to create scatter plot with empty data.")
return plt.figure()
logger.debug(f"Creating scatter plot: {title}")
fig, ax = plt.subplots(figsize=(8, 8)) # Square figure common for scatter
# Determine plot limits, handle potential NaNs
valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
if not np.any(valid_mask):
logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
# Return empty figure, plot would be blank
return fig
y_true_valid = y_true[valid_mask]
y_pred_valid = y_pred[valid_mask]
min_val = min(y_true_valid.min(), y_pred_valid.min())
max_val = max(y_true_valid.max(), y_pred_valid.max())
plot_range = max_val - min_val
if plot_range < 1e-6: # Handle cases where all points are identical
plot_range = 1.0 # Avoid zero range
lim_min = min_val - 0.05 * plot_range
lim_max = max_val + 0.05 * plot_range
ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)
ax.set_title(title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.set_xlim(lim_min, lim_max)
ax.set_ylim(lim_min, lim_max)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
ax.set_aspect('equal', adjustable='box') # Ensure square scaling
fig.tight_layout()
return fig
def create_residuals_plot(
x: np.ndarray,
residuals: np.ndarray,
title: str,
xlabel: str = "Time Index",
ylabel: str = "Residual (Actual - Predicted)",
max_points: Optional[int] = None
) -> plt.Figure:
"""
Create a plot of residuals over time.
Args:
x: The array for the x-axis (e.g., time steps, indices).
residuals: Array of residual values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (x.shape == residuals.shape and x.ndim == 1):
raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
if len(x) == 0:
logger.warning("Attempting to create residuals time plot with empty data.")
return plt.figure()
logger.debug(f"Creating residuals time plot: {title}")
fig, ax = plt.subplots(figsize=(15, 5)) # Often wider than tall
n_points = len(x)
indices = np.arange(n_points)
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = indices[::step]
plot_x = x[::step]
plot_residuals = residuals[::step]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
plot_x = x
plot_residuals = residuals
effective_title = title
ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)
ax.set_title(effective_title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
def create_residuals_distribution_plot(
residuals: np.ndarray,
title: str,
xlabel: str = "Residual Value",
ylabel: str = "Density"
) -> plt.Figure:
"""
Create a distribution plot (histogram and KDE) of residuals using seaborn.
Args:
residuals: Array of residual values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shape is invalid.
"""
if residuals.ndim != 1:
raise ValueError("Input array (residuals) must be 1D.")
if len(residuals) == 0:
logger.warning("Attempting to create residuals distribution plot with empty data.")
return plt.figure()
logger.debug(f"Creating residuals distribution plot: {title}")
fig, ax = plt.subplots(figsize=(8, 6))
# Filter out NaNs before plotting and calculating stats
residuals_valid = residuals[~np.isnan(residuals)]
if len(residuals_valid) == 0:
logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
return fig # Return empty figure
# Use seaborn histplot which combines histogram and KDE
try:
sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
except Exception as e:
logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
# Fallback to basic matplotlib histogram if seaborn fails
ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)
ylabel = "Frequency" # Adjust label if only histogram shown
mean_res = np.mean(residuals_valid)
std_res = np.std(residuals_valid)
ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')
ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
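A small standalone sketch of how these helpers compose (synthetic data; assumes the functions above are in scope):

    import numpy as np

    setup_plot_style()
    x = np.arange(200)
    y_true = np.sin(x / 24.0)
    y_pred = y_true + np.random.normal(0.0, 0.1, size=x.shape)

    fig = create_time_series_plot(x, y_true, y_pred, title="Synthetic check", max_points=100)
    save_plot(fig, "output/reports/synthetic_check.png")  # save_plot closes the figure

    fig_res = create_residuals_distribution_plot(y_true - y_pred, title="Synthetic residuals")
    save_plot(fig_res, "output/reports/synthetic_residuals.png")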

View File

@ -1,18 +1,103 @@
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
from typing import Optional, Dict, Any, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Assuming config_model is in sibling directory utils/
from forecasting_model.utils.config_model import ModelConfig, TrainingConfig
logger = logging.getLogger(__name__)
class LSTMForecastLightningModule(pl.LightningModule):
"""
PyTorch Lightning Module for LSTM-based time series forecasting.
Encapsulates the model architecture, training, validation, and test logic.
Uses torchmetrics for efficient metric calculation.
"""
def __init__(
self,
model_config: ModelConfig,
train_config: TrainingConfig,
input_size: int,
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
):
super().__init__()
# --- Validate & Store Configs ---
# Validate the input_size passed during instantiation
if input_size <= 0:
raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")
# Store the validated input_size directly for use in layer definitions
self._input_size = input_size # Use a temporary attribute before hparams are saved
# Ensure forecast_horizon is set in the config for the output layer
if not hasattr(model_config, 'forecast_horizon') or model_config.forecast_horizon is None or model_config.forecast_horizon <= 0:
raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
self.output_size = model_config.forecast_horizon
# Store configurations - input_size argument will be saved via save_hyperparameters
self.model_config = model_config
self.train_config = train_config
self.target_scaler = target_scaler # Store scaler for this fold
# Use save_hyperparameters() to automatically log configs and allow loading
# Pass input_size explicitly to be saved in hparams
# Exclude scaler as it's stateful and fold-specific
self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])
# --- Define Model Layers ---
# Access input_size via hparams now
self.lstm = nn.LSTM(
input_size=self.hparams.input_size,
hidden_size=self.hparams.model_config.hidden_size,
num_layers=self.hparams.model_config.num_layers,
batch_first=True, # Input shape: (batch, seq_len, features)
dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
)
self.dropout = nn.Dropout(self.hparams.model_config.dropout)
# Output layer maps LSTM hidden state to the forecast horizon
# We typically take the output of the last time step
self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)
# Optional residual connection handling
self.use_residual_skips = self.hparams.model_config.use_residual_skips
self.residual_projection = None
if self.use_residual_skips:
# If input size doesn't match hidden size, project input
if self.hparams.input_size != self.hparams.model_config.hidden_size:
# Use hparams.input_size here
self.residual_projection = nn.Linear(self.hparams.input_size, self.hparams.model_config.hidden_size)
logger.info("Residual connections enabled.")
if self.residual_projection:
logger.info("Residual projection layer added.")
# --- Define Loss Function ---
if self.hparams.train_config.loss_function.upper() == 'MSE':
self.criterion = nn.MSELoss()
elif self.hparams.train_config.loss_function.upper() == 'MAE':
self.criterion = nn.L1Loss()
else:
raise ValueError(f"Unsupported loss function: {self.hparams.train_config.loss_function}")
# --- Define Metrics (TorchMetrics) ---
metrics = torchmetrics.MetricCollection([
torchmetrics.MeanAbsoluteError(),
torchmetrics.MeanSquaredError(squared=False) # RMSE
])
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@ -24,5 +109,184 @@ class LSTMForecastModel(nn.Module):
Returns:
Predictions tensor of shape (batch_size, forecast_horizon)
"""
# LSTM forward pass
lstm_out, (hidden, cell) = self.lstm(x) # Shape: (batch, seq_len, hidden_size)
# Output from the last time step
last_time_step_out = lstm_out[:, -1, :] # Shape: (batch_size, hidden_size)
# Apply dropout
last_time_step_out = self.dropout(last_time_step_out)
# Optional Residual Connection
if self.use_residual_skips:
residual = x[:, -1, :] # Input from the last time step: (batch_size, input_size)
if self.residual_projection:
residual = self.residual_projection(residual) # Project to hidden_size
last_time_step_out = last_time_step_out + residual
# Final fully connected layer
predictions = self.fc(last_time_step_out) # Shape: (batch_size, output_size/horizon)
return predictions # Shape: (batch_size, forecast_horizon)
def _calculate_loss(self, outputs, targets):
# Ensure shapes match before loss calculation
if outputs.shape != targets.shape:
# Squeeze potential extra dim: (batch, horizon, 1) -> (batch, horizon)
if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
outputs = outputs.squeeze(-1)
if outputs.shape != targets.shape:
raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
return self.criterion(outputs, targets)
def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
"""Helper to inverse transform data using the stored target scaler."""
if self.target_scaler is None:
# logger.warning("Cannot inverse transform: target_scaler not available.")
return None # Cannot inverse transform
# Scaler expects 2D input (N, 1)
# Ensure data is on CPU and is float64 for sklearn scaler typically
data_cpu = data.detach().cpu().numpy().astype(np.float64)
original_shape = data_cpu.shape
if data_cpu.ndim == 1:
data_flat = data_cpu.reshape(-1, 1)
elif data_cpu.ndim == 2: # (batch, horizon)
data_flat = data_cpu.reshape(-1, 1)
else:
logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
data_flat = data_cpu.reshape(-1, 1)
try:
inversed_np = self.target_scaler.inverse_transform(data_flat)
# Return as tensor on the original device
inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
# Reshape back? Or keep flat? Keep flat for direct metric use often.
return inversed_tensor.flatten()
# return inversed_tensor.reshape(original_shape) # If original shape needed
except Exception as e:
logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
return None # Return None if inverse transform fails
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
x, y = batch # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
outputs = self(x) # Scaled outputs: (batch, horizon)
loss = self._calculate_loss(outputs, y)
# Log scaled metrics
metrics = self.train_metrics(outputs, y) # Update internal state
self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True) # Log all metrics in collection
return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
x, y = batch
outputs = self(x) # Scaled outputs
loss = self._calculate_loss(outputs, y)
# Log scaled metrics
metrics = self.val_metrics(outputs, y) # Update internal state
self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)
# Log MAE on ORIGINAL scale if scaler is available (often the primary metric for checkpointing/Optuna)
if self.target_scaler is not None:
outputs_inv = self._inverse_transform(outputs)
y_inv = self._inverse_transform(y)
if outputs_inv is not None and y_inv is not None:
# Ensure shapes are compatible (flattened by _inverse_transform)
if outputs_inv.shape == y_inv.shape:
self.val_mae_original_scale.update(outputs_inv, y_inv)
self.log('val_mae_orig_scale', self.val_mae_original_scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
else:
logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
else:
logger.warning("Could not compute original scale MAE in validation due to inverse transform failure.")
def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
# Optional: Keep this method ONLY if you want trainer.test() to log metrics.
# For getting predictions for evaluation, use predict_step.
# If evaluate_fold_predictions handles all metrics, this can be simplified or removed.
# Let's simplify it for now to only log loss if needed.
try:
x, y = batch
outputs = self(x)
loss = self._calculate_loss(outputs, y)
# Log scaled test metrics if you still want trainer.test() to report them
metrics = self.test_metrics(outputs, y)
self.log('test_loss_step', loss, on_step=True, on_epoch=False) # Log step loss if needed
self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True)
# No return needed if just logging
except Exception as e:
logger.error(f"Error occurred in test_step for batch {batch_idx}: {e}", exc_info=True)
# Optionally log something to indicate failure
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Dict[str, torch.Tensor]:
"""
Runs inference for prediction and returns scaled predictions and targets.
'batch' might contain only features depending on the DataLoader setup for predict.
Let's assume the test_loader yields (x, y) pairs for convenience here.
"""
if isinstance(batch, (list, tuple)) and len(batch) == 2:
x, y = batch
else:
# Assume batch contains only features if not a pair
x = batch
y = None # No targets available during prediction if dataloader only yields features
outputs = self(x) # Scaled outputs
result = {'preds_scaled': outputs.detach().cpu()}
if y is not None:
# Include targets if they were part of the batch (e.g., using test_loader for predict)
result['targets_scaled'] = y.detach().cpu()
return result
def configure_optimizers(self) -> Union[optim.Optimizer, Tuple[List[optim.Optimizer], List[Dict[str, Any]]]]:
"""
Configure the optimizer (Adam) and optional LR scheduler.
"""
optimizer = optim.Adam(
self.parameters(),
lr=self.hparams.train_config.learning_rate # Access lr via hparams
)
logger.info(f"Configured Adam optimizer with LR: {self.hparams.train_config.learning_rate}")
# Optional LR Scheduler configuration
scheduler_config = None
if hasattr(self.hparams.train_config, 'scheduler_step_size') and \
self.hparams.train_config.scheduler_step_size is not None and \
hasattr(self.hparams.train_config, 'scheduler_gamma') and \
self.hparams.train_config.scheduler_gamma is not None:
if self.hparams.train_config.scheduler_step_size > 0 and 0 < self.hparams.train_config.scheduler_gamma < 1:
logger.info(f"Configuring StepLR scheduler with step_size={self.hparams.train_config.scheduler_step_size} "
f"and gamma={self.hparams.train_config.scheduler_gamma}")
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=self.hparams.train_config.scheduler_step_size,
gamma=self.hparams.train_config.scheduler_gamma
)
scheduler_config = {
'scheduler': scheduler,
'interval': 'epoch', # or 'step'
'frequency': 1,
'monitor': 'val_loss', # Optional: Only step if monitor improves (for ReduceLROnPlateau)
}
else:
logger.warning("Scheduler parameters provided but invalid (step_size must be >0, 0<gamma<1). No scheduler configured.")
# Example for ReduceLROnPlateau (if needed later)
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
# scheduler_config = {'scheduler': scheduler, 'monitor': 'val_loss'}
if scheduler_config:
return [optimizer], [scheduler_config]
else:
return optimizer
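A minimal smoke-test sketch for the module above (values are illustrative; in the real pipeline `input_size` comes from `prepare_fold_data_and_loaders` and the scaler from the fold's preprocessing):

    import torch
    from forecasting_model.utils.config_model import ModelConfig, TrainingConfig
    from forecasting_model.model import LSTMForecastLightningModule

    model_cfg = ModelConfig(hidden_size=64, num_layers=2, dropout=0.2, forecast_horizon=24)
    train_cfg = TrainingConfig(batch_size=32, epochs=10, learning_rate=1e-3)
    module = LSTMForecastLightningModule(
        model_config=model_cfg,
        train_config=train_cfg,
        input_size=10,       # number of engineered features (known only after data prep)
        target_scaler=None,  # skip original-scale logging in this sketch
    )
    x = torch.randn(8, 72, 10)     # (batch, sequence_length, input_size)
    y_hat = module(x)
    assert y_hat.shape == (8, 24)  # (batch, forecast_horizon)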

View File

@ -1,50 +0,0 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional, Dict, Any
from ..utils.config_model import TrainingConfig
class Trainer:
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
loss_fn: nn.Module,
device: torch.device,
config: TrainingConfig,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
target_scaler: Optional[Any] = None
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.loss_fn = loss_fn
self.device = device
self.config = config
self.scheduler = scheduler
self.target_scaler = target_scaler
# TODO: Initialize optimizer (Adam)
# TODO: Initialize early stopping if configured
def train_epoch(self) -> Dict[str, float]:
"""
Train for one epoch.
"""
# TODO: Implement training loop for one epoch
pass
def evaluate(self, loader: DataLoader) -> Dict[str, float]:
"""
Evaluate model on given data loader.
"""
# TODO: Implement evaluation with metrics on original scale
pass
def train(self) -> Dict[str, Any]:
"""
Main training loop with validation and early stopping.
"""
# TODO: Implement full training loop with validation
pass

View File

@ -2,4 +2,32 @@
Utility functions and classes for the forecasting model.
This package contains configuration models, helper functions, and other utilities.
"""
"""
# Expose configuration models
from .config_model import (
MainConfig,
DataConfig,
FeatureConfig,
ModelConfig,
TrainingConfig,
CrossValidationConfig,
EvaluationConfig,
OptunaConfig,
WaveletTransformConfig, # Expose nested configs if they might be used directly
ClippingConfig
)
# Define __all__ for explicit public API
__all__ = [
"MainConfig",
"DataConfig",
"FeatureConfig",
"ModelConfig",
"TrainingConfig",
"CrossValidationConfig",
"EvaluationConfig",
"OptunaConfig",
"WaveletTransformConfig",
"ClippingConfig",
]
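With this re-export in place, downstream code can import the configuration models from the package directly, e.g. `from forecasting_model.utils import MainConfig, FeatureConfig`.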

View File

@ -1,62 +1,151 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional, List, Union, Literal
from enum import Enum
# --- Nested Configs ---
class WaveletTransformConfig(BaseModel):
"""Configuration for optional wavelet transform features."""
apply: bool = False
target_or_feature: str = "target"
target_or_feature: Literal['target', 'feature'] = "target"
wavelet_type: str = "db4"
level: int = Field(3, gt=0) # Level must be positive
use_coeffs: List[str] = ["approx", "detail_1"]
class ClippingConfig(BaseModel):
"""Configuration for optional feature clipping."""
apply: bool = False
clip_min: float = -5.0
clip_max: float = 5.0
@model_validator(mode='after')
def check_clip_range(self) -> 'ClippingConfig':
if self.apply and self.clip_max <= self.clip_min:
raise ValueError(f'clip_max ({self.clip_max}) must be greater than clip_min ({self.clip_min}) when clipping is applied.')
return self
# --- Main Config Sections ---
class DataConfig(BaseModel):
"""Configuration related to data loading and initial preparation."""
data_path: str = Field(..., description="Path to the input CSV data file.")
# --- Raw Data Specifics ---
raw_datetime_col: str = Field(..., description="Name of the raw datetime column in the CSV (e.g., 'MTU (CET/CEST)')")
raw_target_col: str = Field(..., description="Name of the raw target/price column in the CSV (e.g., 'Day-ahead Price [EUR/MWh]')")
raw_datetime_format: str = '%d.%m.%Y %H:%M' # Example, make it configurable if needed
# --- Standardized Names & Processing ---
datetime_col: str = Field(..., description="Standardized name for the datetime index after processing (e.g., 'Timestamp')")
target_col: str = Field(..., description="Standardized name for the target column after processing (e.g., 'Price')")
expected_frequency: Optional[str] = Field('h', description="Expected pandas frequency string (e.g., 'h', 'D', '15min'). If null, no frequency check/setting is performed.")
fill_initial_target_nans: bool = Field(True, description="Forward/backward fill NaNs in the target column immediately after loading?")
class FeatureConfig(BaseModel):
"""Configuration for feature engineering and preprocessing."""
sequence_length: int = Field(..., gt=0)
forecast_horizon: int = Field(..., gt=0)
lags: List[int] = []
rolling_window_sizes: List[int] = []
use_time_features: bool = True
sinus_curve: bool = False # Added
cosin_curve: bool = False # Added
wavelet_transform: Optional[WaveletTransformConfig] = None
fill_nan: Optional[Union[str, float, int]] = 'ffill' # Added (e.g., 'ffill', 0)
clipping: ClippingConfig = ClippingConfig() # Default instance
scaling_method: Optional[Literal['standard', 'minmax']] = 'standard' # Added literal validation
@field_validator('lags', 'rolling_window_sizes')
@classmethod
def check_positive_list_values(cls, v: List[int]) -> List[int]:
if any(val <= 0 for val in v):
raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
return v
class ModelConfig(BaseModel):
"""Configuration for the forecasting model architecture."""
# input_size: Optional[int] = Field(None, gt=0) # Removed: Determined dynamically
hidden_size: int = Field(..., gt=0)
num_layers: int = Field(..., gt=0)
dropout: float = Field(..., ge=0.0, le=1.0)
use_residual_skips: bool = False
output_size: Optional[int] = None # Will be calculated
# Add forecast_horizon here to ensure LightningModule gets it directly
forecast_horizon: Optional[int] = Field(None, gt=0) # Will be set from FeatureConfig
class TrainingConfig(BaseModel):
"""Configuration for the training process (PyTorch Lightning)."""
batch_size: int = Field(..., gt=0)
epochs: int = Field(..., gt=0) # Max epochs
learning_rate: float = Field(..., gt=0.0)
loss_function: Literal['MSE', 'MAE'] = 'MSE'
# device: str = 'auto' # Handled by PL Trainer accelerator/devices args
early_stopping_patience: Optional[int] = Field(None, ge=1) # Patience must be >= 1 if set
scheduler_step_size: Optional[int] = Field(None, gt=0)
scheduler_gamma: Optional[float] = Field(None, gt=0.0, lt=1.0)
gradient_clip_val: Optional[float] = Field(None, ge=0.0) # Added
num_workers: int = Field(0, ge=0) # Added
precision: Literal[16, 32, 64, 'bf16'] = 32 # Added
class CrossValidationConfig(BaseModel):
"""Configuration for time series cross-validation."""
n_splits: int = Field(..., gt=0)
test_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the test set.")
val_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the validation set.")
initial_train_size: Optional[Union[int, float]] = Field(None, gt=0.0, description="Size of the fixed training window (absolute number or fraction of total data > 0). If null, estimated automatically.")
class EvaluationConfig(BaseModel):
"""Configuration for the final evaluation process."""
# metrics: List[str] = ['MAE', 'RMSE'] # Defined internally now
eval_batch_size: int = Field(..., gt=0)
save_plots: bool = True
plot_sample_size: Optional[int] = Field(1000, gt=0) # Max points for plots
class OptunaConfig(BaseModel):
"""Optional configuration for Optuna hyperparameter optimization."""
enabled: bool = False
n_trials: int = Field(20, gt=0)
storage: Optional[str] = None # e.g., "sqlite:///output/hpo_results/study.db"
direction: Literal['minimize', 'maximize'] = 'minimize'
metric_to_optimize: str = 'val_mae_orig_scale'
pruning: bool = True
# --- Top-Level Configuration Model ---
class MainConfig(BaseModel):
"""Main configuration model nesting all sections."""
project_name: str = "TimeSeriesForecasting"
random_seed: Optional[int] = 42 # Added top-level seed
data: DataConfig
features: FeatureConfig
model: ModelConfig # ModelConfig no longer contains input_size
training: TrainingConfig
cross_validation: CrossValidationConfig
evaluation: EvaluationConfig
optuna: Optional[OptunaConfig] = OptunaConfig() # Added optional Optuna config
@model_validator(mode='after')
def check_forecast_horizon_consistency(self) -> 'MainConfig':
# Ensure model config gets forecast horizon from features config if not set
if self.features and self.model:
if self.model.forecast_horizon is None:
# If model config doesn't have it, set it from features config
self.model.forecast_horizon = self.features.forecast_horizon
elif self.model.forecast_horizon != self.features.forecast_horizon:
# If both are set but differ, raise error
raise ValueError(
f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
)
# After potential setting, ensure model.forecast_horizon is actually set
if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
# Input size check is removed as it's not part of static config anymore
return self
class Config:
# Example configuration for Pydantic itself
validate_assignment = True # Re-validate on assignment
# extra = 'forbid' # Forbid extra fields not defined in schema
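A short validation sketch for the custom validators above (Pydantic wraps the `ValueError`s raised inside validators into a `ValidationError`):

    from pydantic import ValidationError
    from forecasting_model.utils.config_model import ClippingConfig, FeatureConfig

    try:
        ClippingConfig(apply=True, clip_min=5.0, clip_max=-5.0)  # inverted range
    except ValidationError as err:
        print(err)  # clip_max must be greater than clip_min when clipping is applied

    try:
        FeatureConfig(sequence_length=72, forecast_horizon=24, lags=[24, -1])  # negative lag
    except ValidationError as err:
        print(err)  # lags/rolling_window_sizes must contain only positive values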

forecasting_model_run.py (new file, 468 lines)
View File

@ -0,0 +1,468 @@
import argparse
import logging
import sys
import os
import random
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
# Import necessary components from your project structure
# Assuming forecasting_model is a package installable or in PYTHONPATH
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from typing import Dict, List, Any, Optional
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
# Configure logging early. Level might be adjusted by config.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments."""
parser = argparse.ArgumentParser(
description="Run the Time Series Forecasting training pipeline.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--seed',
type=int,
default=None, # Default to None, use config value if not provided
help="Override random seed defined in config."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)
parser.add_argument(
'--output-dir',
type=str,
default='output/cv_results', # Default output base directory
help="Base directory for saving cross-validation results (checkpoints, logs, plots)."
)
args = parser.parse_args()
return args
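# Example invocations (illustrative; flag names come from parse_arguments above):
#   python forecasting_model_run.py --config config.yaml --output-dir output/cv_results --seed 42
#   python forecasting_model_run.py -c config.yaml --debug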
# --- Helper Functions ---
def load_config(config_path: Path) -> MainConfig:
"""
Load and validate configuration from YAML file using Pydantic.
Args:
config_path: Path to the YAML configuration file.
Returns:
Validated MainConfig object.
Raises:
FileNotFoundError: If the config file doesn't exist.
yaml.YAMLError: If the file is not valid YAML.
pydantic.ValidationError: If the config doesn't match the schema.
"""
if not config_path.is_file():
logger.error(f"Configuration file not found at: {config_path}")
raise FileNotFoundError(f"Config file not found: {config_path}")
logger.info(f"Loading configuration from: {config_path}")
try:
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
# Validate configuration using Pydantic model
config = MainConfig(**config_dict)
logger.info("Configuration loaded and validated successfully.")
return config
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
raise
except Exception as e: # Catches Pydantic validation errors too
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
raise
def set_seeds(seed: Optional[int] = 42) -> None:
"""
Set random seeds for reproducibility across libraries.
Args:
seed: The seed value to use. If None, uses default 42.
"""
if seed is None:
seed = 42
logger.warning(f"No seed provided, using default seed: {seed}")
else:
logger.info(f"Setting random seed: {seed}")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ensure reproducibility for CUDA operations where possible
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU
# These settings can slow down training but improve reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# PyTorch Lightning seeding (optional, as we seed torch directly)
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""
Calculate mean and standard deviation of metrics across folds.
Handles potential NaN values by ignoring them.
Args:
all_fold_metrics: A list where each element is a dictionary of
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
Returns:
A dictionary where keys are metric names and values are dicts
containing 'mean' and 'std' for that metric across folds.
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
"""
if not all_fold_metrics:
logger.warning("Received empty list for metric aggregation.")
return {}
aggregated: Dict[str, Dict[str, float]] = {}
# Get metric names from the first valid fold's results
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
if not first_valid_metrics:
logger.warning("No valid fold metrics found for aggregation.")
return {}
metric_names = list(first_valid_metrics.keys())
for metric in metric_names:
# Collect values for this metric across all folds, ignoring NaNs
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
valid_values = [v for v in values if v is not None and not np.isnan(v)]
if not valid_values:
logger.warning(f"No valid values found for metric '{metric}' across folds.")
mean_val = np.nan
std_val = np.nan
else:
mean_val = float(np.mean(valid_values))
std_val = float(np.std(valid_values))
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
aggregated[metric] = {'mean': mean_val, 'std': std_val}
return aggregated
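# Example (sketch): aggregate_cv_metrics([{'MAE': 1.0, 'RMSE': 2.0}, {'MAE': 3.0, 'RMSE': np.nan}])
# returns {'MAE': {'mean': 2.0, 'std': 1.0}, 'RMSE': {'mean': 2.0, 'std': 0.0}}; NaNs are ignored.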
def save_results(results: Dict, filename: Path):
"""Save dictionary results to a JSON file."""
try:
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
json.dump(results, f, indent=4)
logger.info(f"Saved results to {filename}")
except Exception as e:
logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
"""Runs the full cross-validation training and evaluation pipeline."""
start_time = time.perf_counter()
# --- Data Loading ---
try:
df = load_raw_data(config.data)
except Exception as e:
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
sys.exit(1) # Cannot proceed without data
# --- Cross-Validation Setup ---
try:
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
all_fold_test_metrics: List[Dict[str, float]] = []
all_fold_best_val_scores: Dict[int, Optional[float]] = {} # Store best val score per fold
# --- Cross-Validation Loop ---
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_start_time = time.perf_counter()
fold_id = fold_num + 1
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")
try:
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
)
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# --- Model Initialization ---
# Pass input_size directly, ModelConfig no longer holds it.
# Ensure forecast horizon is consistent (checked in MainConfig validation)
current_model_config = config.model # Use the validated model config
model = LSTMForecastLightningModule(
model_config=current_model_config, # Does not contain input_size
train_config=config.training,
input_size=input_size, # Pass the dynamically determined input_size
target_scaler=target_scaler # Pass the fold-specific scaler
)
logger.info("LSTMForecastLightningModule initialized.")
# --- PyTorch Lightning Callbacks ---
# Monitor the validation MAE on the original scale (logged by LightningModule)
monitor_metric = "val_mae_orig_scale"
monitor_mode = "min"
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001, # Minimum change to qualify as improvement
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
# Checkpoint callback to save the best model based on validation metric
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
# Learning rate monitor callback
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
callbacks.append(early_stop_callback)
# --- PyTorch Lightning Logger ---
# Log metrics to a CSV file within the fold directory
pl_logger = CSVLogger(save_dir=str(output_base_dir), name=f"fold_{fold_id:02d}", version='logs')
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
# Determine accelerator and devices based on PyTorch check
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs [0], [1] etc.
precision = getattr(config.training, 'precision', 32) # Default to 32-bit
trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
enable_progress_bar=True, # Set to False for less verbose runs (e.g., HPO)
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
# deterministic=True, # For stricter reproducibility (can slow down)
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")
# Store best validation score for this fold
best_val_score = trainer.checkpoint_callback.best_model_score
best_model_path = trainer.checkpoint_callback.best_model_path
all_fold_best_val_scores[fold_id] = best_val_score.item() if best_val_score else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
logger.info(f"Best model checkpoint path: {best_model_path}")
else:
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path = None # Ensure evaluation doesn't try to load 'best' if checkpointing failed
# --- Prediction on Test Set ---
# Use trainer.predict() to get model outputs
logger.info(f"Starting prediction for Fold {fold_id} using best checkpoint...")
# predict_step returns dict {'preds_scaled': ..., 'targets_scaled': ...}
# We pass the test_loader here, which yields (x, y) pairs, so predict_step will include targets
prediction_results_list = trainer.predict(
# model=model, # Not needed if using ckpt_path
ckpt_path=best_model_path if best_model_path else 'last', # Load best model or last if best failed
dataloaders=test_loader
# return_predictions=True # Default is True
)
# Check if prediction returned results
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
else:
try:
# Concatenate predictions and targets from predict_step results
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
# Check if targets were included (they should be if using test_loader)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
else:
# This case shouldn't happen if using test_loader, but good safeguard
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
raise ValueError("Targets missing from prediction results.")
# --- Final Evaluation & Plotting ---
logger.info(f"Processing prediction results for Fold {fold_id}...")
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled,
y_pred_scaled=all_preds_scaled,
target_scaler=target_scaler, # Use the scaler from this fold
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=output_base_dir, # Base dir for saving plots etc.
# time_index=df.iloc[test_idx].index # Pass time index if needed
)
# Save fold metrics
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
all_fold_test_metrics.append(fold_metrics)
# --- (Optional) Log final test metrics using trainer.test() ---
# If you want the metrics logged by test_step aggregated, call test now.
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
# try:
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
# except Exception as e:
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
except Exception as e:
# Catch errors during the fold processing (data prep, training, prediction, eval)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
all_fold_test_metrics.append({'MAE': np.nan, 'RMSE': np.nan})
# --- Cleanup per fold ---
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared CUDA cache.")
fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
# --- Aggregation and Final Reporting ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
# Save aggregated results
final_results = {
'aggregated_test_metrics': aggregated_metrics,
'per_fold_test_metrics': all_fold_test_metrics,
'per_fold_best_val_scores': all_fold_best_val_scores,
}
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
# Log final results
logger.info("--- Aggregated Cross-Validation Test Results ---")
if aggregated_metrics:
for metric, stats in aggregated_metrics.items():
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
else:
logger.warning("No metrics available for aggregation.")
logger.info("-------------------------------------------------")
end_time = time.perf_counter()
logger.info(f"Training pipeline finished successfully in {end_time - start_time:.2f} seconds.")
# --- Main Execution ---
def run():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
output_dir = Path(args.output_dir)
# Adjust log level if debug flag is set
if args.debug:
logger.setLevel(logging.DEBUG)
logger.debug("# --- Debug mode enabled. --- #")
# --- Configuration Loading ---
try:
config = load_config(config_path)
except Exception:
# Error already logged in load_config
sys.exit(1)
# --- Seed Setting ---
# Use command-line seed if provided, otherwise use config seed
seed = args.seed if args.seed is not None else getattr(config, 'random_seed', 42)
set_seeds(seed)
# --- Pipeline Execution ---
try:
run_training_pipeline(config, output_dir)
except SystemExit as e:
logger.warning(f"Pipeline exited with code {e.code}.")
sys.exit(e.code) # Propagate exit code
except Exception as e:
logger.critical(f"An critical error occurred during pipeline execution: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
run()

View File

@ -10,9 +10,6 @@ from forecasting_model.data_processing import (
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastModel
from forecasting_model.trainer import Trainer
from forecasting_model.evaluation import evaluate_fold
# Configure logging
logging.basicConfig(

optuna_run.py (new file, 395 lines)
View File

@ -0,0 +1,395 @@
import argparse
import logging
import sys
import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# Import the Optuna callback for pruning
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
# Import necessary components from the forecasting_model package
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
# from forecasting_model.evaluation import ...
from typing import Dict, List, Any, Optional
# Import helper functions from forecasting_model.py (or move them to a shared utils file)
# For now, let's redefine simplified versions or assume they exist in utils
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
root_logger = logging.getLogger()
logger = logging.getLogger(__name__) # Logger for this script
optuna_lg = logging.getLogger('optuna') # Optuna's logger
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments for Optuna HPO."""
parser = argparse.ArgumentParser(
description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the BASE YAML configuration file."
)
parser.add_argument(
'--output-dir',
type=str,
default='output/hpo_results',
help="Directory for saving Optuna study database and potentially best trial info."
)
parser.add_argument(
'--study-name',
type=str,
default='lstm_forecasting_hpo',
help="Name for the Optuna study."
)
parser.add_argument(
'--n-trials',
type=int,
default=20,
help="Number of Optuna trials to run."
)
parser.add_argument(
'--storage-db',
type=str,
default=None, # Default to in-memory if not specified
help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage."
)
parser.add_argument(
'--metric-to-optimize',
type=str,
default='val_mae_orig_scale',
help="Metric logged during validation to optimize (must match metric name in LightningModule)."
)
parser.add_argument(
'--direction',
type=str,
default='minimize',
choices=['minimize', 'maximize'],
help="Direction for Optuna optimization."
)
parser.add_argument(
'--pruning',
action='store_true',
help="Enable Optuna's trial pruning based on intermediate validation results."
)
parser.add_argument(
'--seed',
type=int,
default=42, # Fixed seed for the HPO process itself
help="Random seed for the main HPO script."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)
args = parser.parse_args()
return args
# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig, # Pass the loaded base config
df: pd.DataFrame, # Pass the loaded data
output_base_dir: Path, # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
metric_to_optimize: str,
enable_pruning: bool
) -> float:
"""
Optuna objective function. Trains and evaluates one set of hyperparameters
using cross-validation and returns the average validation metric.
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()
# --- 1. Suggest Hyperparameters ---
# Make a deep copy of the base config to modify for this trial
# Using dict conversion and back might be easier than Pydantic's copy for deep nested updates
try:
trial_config_dict = copy.deepcopy(base_config.dict()) # Convert to dict for easier modification
except Exception as e:
logger.error(f"Failed to deep copy base configuration: {e}")
raise # Cannot proceed without config
# Suggest values for hyperparameters we want to tune
# Example suggestions (adjust ranges and types as needed):
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
# Example: Suggest sequence length? (Requires careful handling as it affects data prep)
# trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)
# --- 2. Re-validate Trial Config (Optional but Recommended) ---
try:
trial_config = MainConfig(**trial_config_dict)
logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
# Return a high value (for minimization) to penalize invalid configs
return float('inf')
# --- 3. Run Cross-Validation for this Trial ---
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
    # Determine the optimization direction once, outside the fold loop, so the
    # exception handler below can reference monitor_mode even if the first fold fails early.
    monitor_mode = "min" if args.direction == "minimize" else "max"
    fold_best_val_metrics: List[Optional[float]] = []
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_id = fold_num + 1
logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
fold_start_time = time.perf_counter()
# Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
# fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
# fold_trial_dir.mkdir(parents=True, exist_ok=True)
try:
# --- Per-Fold Data Prep ---
# Use trial_config for batch sizes etc.
train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, # Test loader not needed here
target_col=trial_config.data.target_col,
feature_config=trial_config.features,
train_config=trial_config.training,
eval_config=trial_config.evaluation # Pass eval for batch size if needed by prep?
)
# --- Model Instantiation ---
current_model_config = trial_config.model.copy(update={'input_size': input_size,
'forecast_horizon': trial_config.features.forecast_horizon})
model = LSTMForecastLightningModule(
model_config=current_model_config,
train_config=trial_config.training,
target_scaler=target_scaler
)
# --- Callbacks for this Trial/Fold ---
            # Monitor the metric Optuna cares about (monitor_mode is defined above the fold loop)
callbacks = []
if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
early_stopping = EarlyStopping(
monitor=metric_to_optimize,
patience=trial_config.training.early_stopping_patience,
mode=monitor_mode,
verbose=False # Less verbose during HPO
)
callbacks.append(early_stopping)
# Add Optuna Pruning Callback
if enable_pruning:
pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
callbacks.append(pruning_callback)
# Optional: LR Monitor
# callbacks.append(LearningRateMonitor(logging_interval='epoch'))
# --- Trainer for this Trial/Fold ---
trainer = pl.Trainer(
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
devices=1 if torch.cuda.is_available() else None,
max_epochs=trial_config.training.epochs,
callbacks=callbacks,
logger=False, # Disable default PL logging during HPO
enable_checkpointing=False, # Disable checkpoint saving during HPO
enable_progress_bar=False, # Disable progress bar for cleaner logs
enable_model_summary=False, # Disable model summary
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
precision=getattr(trial_config.training, 'precision', 32),
# Log GPU usage if available?
# log_gpu_memory='min_max',
)
# --- Fit the Model ---
logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# --- Get Best Validation Score for Pruning/Reporting ---
# Access the monitored metric value from the trainer's logged metrics or callback state
# Ensure the key matches exactly what's logged in validation_step
best_val_score = trainer.callback_metrics.get(metric_to_optimize)
if best_val_score is None:
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using inf/nan.")
# Handle cases where training might have failed or metric wasn't logged
best_val_score = float('inf') if monitor_mode == 'min' else float('-inf') # Return worst possible value
else:
best_val_score = best_val_score.item() # Convert tensor to float
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")
fold_best_val_metrics.append(best_val_score)
# --- Intermediate Pruning Report (Optional but Recommended) ---
# Report the intermediate value (best score for this fold) to Optuna
# trial.report(best_val_score, fold_id) # Report score at step `fold_id`
# Check if the trial should be pruned based on reported values
# if trial.should_prune():
# logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
# raise optuna.TrialPruned()
logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")
except optuna.TrialPruned:
# Re-raise prune exception to let Optuna handle it
raise
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
# Record a failure for this fold (e.g., append NaN or worst value)
fold_best_val_metrics.append(float('inf') if monitor_mode == 'min' else float('-inf'))
# Optionally: Break the CV loop for this trial if one fold fails catastrophically?
# break
# --- 4. Calculate Average Metric Across Folds ---
if not fold_best_val_metrics:
logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
return float('inf') # Return worst value
# Handle potential infinities or NaNs from failed folds
valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
if not valid_scores:
logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
return float('inf')
average_val_metric = np.mean(valid_scores)
logger.info(f"--- Trial {trial.number}: Finished ---")
logger.info(f" Average validation {metric_to_optimize}: {average_val_metric:.5f}")
logger.info(f" Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
# --- 5. Return Metric for Optuna ---
return average_val_metric
# --- Main HPO Execution ---
def run_hpo():
"""Main execution function for HPO."""
global args # Make args accessible in objective (simplifies passing) - or use functools.partial
args = parse_arguments()
config_path = Path(args.config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) # Ensure output dir exists
# Adjust log level if debug flag is set
if args.debug:
root_logger.setLevel(logging.DEBUG)
optuna_lg.setLevel(logging.DEBUG)
pl_logger.setLevel(logging.DEBUG)
logger.debug("Debug mode enabled.")
else:
# Reduce verbosity during HPO runs
optuna_lg.setLevel(logging.WARNING)
pl_logger.setLevel(logging.INFO) # Keep INFO for PL start/end messages
# --- Configuration Loading ---
try:
base_config = load_config(config_path)
except Exception:
sys.exit(1)
# --- Seed Setting (for HPO script itself) ---
set_seeds(args.seed)
# --- Load Data Once ---
# Assume data doesn't change based on HPs (unless sequence_length is tuned heavily)
try:
logger.info("Loading base dataset...")
df = load_raw_data(base_config.data)
logger.info("Base dataset loaded.")
except Exception as e:
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
sys.exit(1)
# --- Optuna Study Setup ---
storage_path = args.storage_db
if storage_path:
# Ensure directory exists if using SQLite file storage
db_path = Path(storage_path.replace("sqlite:///", ""))
db_path.parent.mkdir(parents=True, exist_ok=True)
storage_path = f"sqlite:///{db_path.resolve()}" # Use absolute path
logger.info(f"Using Optuna storage: {storage_path}")
else:
logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")
try:
# Create or load the study
study = optuna.create_study(
study_name=args.study_name,
storage=storage_path,
direction=args.direction,
load_if_exists=True, # Load previous results if study exists
pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() # Example pruner
)
# --- Run Optimization ---
logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
n_trials=args.n_trials,
timeout=None # Optional: Set timeout in seconds
# Optional: Add callbacks (e.g., logging callback)
)
# --- Report Best Trial ---
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")
best_trial = study.best_trial
logger.info(f"Best trial number: {best_trial.number}")
logger.info(f" Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
logger.info(" Best hyperparameters:")
for key, value in best_trial.params.items():
logger.info(f" {key}: {value}")
# --- Save Best Hyperparameters (Optional) ---
best_params_file = output_dir / f"{args.study_name}_best_params.json"
try:
            import json
            with open(best_params_file, 'w') as f:
                json.dump(best_trial.params, f, indent=4)
logger.info(f"Best hyperparameters saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save best parameters: {e}")
except Exception as e:
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
run_hpo()
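
The script only persists the winning hyperparameters to `<study_name>_best_params.json`; wiring them back into a final training run is left to the caller. Below is a minimal sketch of that step, assuming the same key-to-section mapping used in `objective()`; the `apply_best_params` helper is illustrative and not part of this commit.

```python
# Sketch only: merge <study_name>_best_params.json back into the base config for a
# final training run. The key-to-section mapping mirrors the suggestions in objective();
# apply_best_params() itself does not exist in the repository.
import copy
import json
from pathlib import Path

from forecasting_model.utils.config_model import MainConfig
from forecasting_model_run import load_config

TRAINING_KEYS = {"learning_rate", "batch_size"}
MODEL_KEYS = {"hidden_size", "num_layers", "dropout"}


def apply_best_params(config_path: Path, params_path: Path) -> MainConfig:
    base = load_config(config_path)
    merged = copy.deepcopy(base.dict())
    for key, value in json.loads(params_path.read_text()).items():
        if key in TRAINING_KEYS:
            merged["training"][key] = value
        elif key in MODEL_KEYS:
            merged["model"][key] = value
    return MainConfig(**merged)  # re-validate the merged configuration
```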

68
rules.md Normal file
View File

@@ -0,0 +1,68 @@
## Coding Style Rules & Paradigms
### Configuration Driven
* Uses **Pydantic** heavily (`utils/models.py`) to define configuration schemas.
* Configuration is loaded from a **YAML file** (`config.yaml`) at runtime (`main.py`).
* The `Config` object (or relevant sub-configs) is **passed down** through function calls, making parameters explicit.
* A **template configuration** (`_config.yaml`) is often included within the package.
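
A minimal sketch of this pattern (class and field names here are illustrative; the real schemas live in `utils/models.py`):

```python
# Illustrative sketch only -- the class and field names do not mirror the real schema.
from pathlib import Path
from typing import Optional

import yaml
from pydantic import BaseModel


class DataConfig(BaseModel):
    data_path: Path
    target_col: str = "Price"


class AppConfig(BaseModel):
    debug: bool = False
    random_seed: Optional[int] = None
    data: DataConfig


def load_config(path: Path) -> AppConfig:
    """Load the YAML file and validate it against the schema."""
    with open(path, "r") as f:
        raw = yaml.safe_load(f)
    return AppConfig(**raw)  # raises pydantic.ValidationError on bad or missing fields
```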
### Modularity
* Code is organized into **logical sub-packages** (`io`, `processing`, `pipeline`, `visualization`, `synthesis`, `utils`, `validation`).
* Each sub-package has an `__init__.py`, often used to **expose key functions/classes** to the parent level.
* **Helper functions** (often internal, prefixed with `_`) are frequently used to break down complex logic within modules (e.g., `processing/surface_helper.py`, `pipeline/runner.py` helpers).
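
For illustration, the `__init__.py` re-export convention looks roughly like this (module and function names are hypothetical):

```python
# processing/__init__.py -- hypothetical example of exposing key functions at package level
from .segmentation import segment_volume  # public entry point re-exported for callers

__all__ = ["segment_volume"]
```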
### Logging
* Uses the standard **`logging` library**.
* Loggers are obtained per module using `logger = logging.getLogger(__name__)`.
* **Logging levels** (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`) are used semantically:
* `DEBUG`: Verbose internal steps.
* `INFO`: Major milestones/stages.
* `WARNING`: Recoverable issues or deviations.
* `ERROR`: Specific failures that might be handled.
* `CRITICAL`: Fatal errors causing exits.
* **Root logger configuration** happens in `main.py`, potentially adjusted based on the `debug` flag in the config.
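
A compact example of the per-module pattern (the stage function is made up; the root-logger setup itself stays in `main.py` as described above):

```python
import logging

logger = logging.getLogger(__name__)  # one logger per module, named after the module


def run_stage(stage_name: str) -> None:
    logger.debug("Entering stage '%s'", stage_name)                # verbose internal step
    logger.info("Stage '%s' started", stage_name)                  # major milestone
    logger.warning("Stage '%s' used a fallback path", stage_name)  # recoverable deviation
```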
### Error Handling ("Fail Hard but Helpful")
* The main entry point (`main.py`) uses a **top-level `try...except` block** to catch major failures during config loading or pipeline execution.
* **Critical errors** are logged with tracebacks (`exc_info=True`) and result in `sys.exit(1)`.
* Functions often return a **tuple indicating success/failure** and results/error messages (e.g., `(result_data, error_message)` or `(success_flag, result_data)`).
* Lower-level functions may log errors/warnings but **allow processing to continue** if feasible and configured (e.g., `allow_segmentation_errors`).
* **Specific exceptions** are caught where appropriate (`FileNotFoundError`, `pydicom.errors.InvalidDicomError`, `ValueError`, etc.).
* **Pydantic validation errors** during config loading are treated as critical.
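
A sketch of the `(result_data, error_message)` convention, using hypothetical names:

```python
from typing import Optional, Tuple

import numpy as np


def load_volume(path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """Return (data, error_message); exactly one of the two is None."""
    try:
        return np.load(path), None
    except FileNotFoundError:
        return None, f"Volume file not found: {path}"
    except ValueError as exc:
        return None, f"Could not parse volume '{path}': {exc}"
```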
### Typing
* Consistent use of **Python type hints** (`typing` module: `Optional`, `Dict`, `List`, `Tuple`, `Union`, `Callable`, `Literal`, etc.).
* **Pydantic models** rely heavily on type hints for validation.
### Data Structures
* **Pydantic models** define primary configuration and result structures (e.g., `Config`, `ProcessingResult`, `CombinedDicomDataset`).
* **NumPy arrays** are fundamental for image/volume data.
* **Pandas DataFrames** are used for aggregating results, metadata, and creating reports (Excel).
* Standard **Python dictionaries** are used extensively for metadata and intermediate data passing.
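
For example, a result model plus DataFrame aggregation might look like this (field names are illustrative, not the actual `ProcessingResult` definition):

```python
from typing import Dict, List

import pandas as pd
from pydantic import BaseModel


class CaseResult(BaseModel):  # illustrative stand-in for models like ProcessingResult
    case_id: str
    n_slices: int
    metadata: Dict[str, str] = {}


def build_report(results: List[CaseResult]) -> pd.DataFrame:
    """Aggregate per-case results into a DataFrame, e.g. for an Excel export."""
    return pd.DataFrame([r.dict() for r in results])
```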
### Naming Conventions
* Follows **PEP 8**: `snake_case` for variables and functions, `PascalCase` for classes.
* Internal helper functions are typically prefixed with an **underscore (`_`)**.
* Constants are defined in **`UPPER_SNAKE_CASE`** (often in a dedicated `utils/constants.py`).
### Documentation
* **Docstrings** are present for most functions and classes, explaining purpose, arguments (`Args:`), and return values (`Returns:`).
* Minimal **inline comments**; code aims to be self-explanatory, with docstrings providing higher-level context. (Matches your custom instructions).
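
The docstring convention, shown on a hypothetical helper:

```python
import numpy as np


def normalize_volume(volume: np.ndarray, out_range: tuple = (0.0, 1.0)) -> np.ndarray:
    """Rescale voxel intensities into the given output range.

    Args:
        volume: Input volume as a NumPy array.
        out_range: (min, max) tuple defining the output intensity range.

    Returns:
        A new array with rescaled intensities; the input is not modified.
    """
    lo, hi = out_range
    v_min, v_max = float(volume.min()), float(volume.max())
    if v_max == v_min:
        return np.full_like(volume, lo, dtype=float)
    return (volume - v_min) / (v_max - v_min) * (hi - lo) + lo
```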
### Dependencies
* Managed via `requirements.txt`.
* Uses standard **scientific Python stack** (`numpy`, `pandas`, `scipy`, `scikit-image`, `matplotlib`), **domain-specific libraries** (`pydicom`), **utility libraries** (`PyYAML`, `joblib`, `tqdm`, `openpyxl`), and `pydantic` for configuration/validation.
### Parallelism
* Uses **`joblib`** for parallel processing, configurable via the main config (`mainprocess_core_count`, `subprocess_core_count`).
* Parallelism can be **disabled** via configuration or debug mode.
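
A hedged sketch of the joblib usage; the worker function and the sequential fallback for debug mode are illustrative:

```python
from typing import Dict, List

from joblib import Parallel, delayed


def _process_case(case_id: str) -> Dict[str, str]:
    # Placeholder for the real per-case work.
    return {"case_id": case_id, "status": "ok"}


def process_all(case_ids: List[str], n_jobs: int = 1) -> List[Dict[str, str]]:
    """Run per-case processing; n_jobs=1 keeps everything sequential (debug mode)."""
    if n_jobs == 1:
        return [_process_case(c) for c in case_ids]
    return Parallel(n_jobs=n_jobs)(delayed(_process_case)(c) for c in case_ids)
```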