intermediate backup
@@ -14,7 +1,6 @@ mpl_logger.setLevel(logging.WARNING) # Example: set to WARNING or ERROR

# --- Basic Logging Setup ---
# Configure logging early to catch basic issues.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
@@ -1,22 +1,88 @@

# Configuration for Time Series Forecasting Pipeline

project_name: "TimeSeriesForecasting"  # Name for the project/run
random_seed: 42                        # Optional: Global random seed for reproducibility

# --- Data Loading Configuration ---
data:
  data_path: "data/Day-ahead_Prices_60min.csv"  # Path to your CSV
  # --- Raw Data Specifics ---
  raw_datetime_col: "MTU (CET/CEST)"            # EXACT name in your raw CSV
  raw_target_col: "Day-ahead Price [EUR/MWh]"   # EXACT name in your raw CSV
  raw_datetime_format: '%d.%m.%Y %H:%M'         # Format string is now hardcoded in load_raw_data based on analysis

  # --- Standardized Names & Processing ---
  datetime_col: "Timestamp"       # Desired name for the index after processing
  target_col: "Price"             # Desired name for the target column after processing
  expected_frequency: "h"         # Expected frequency ('h', 'D', '15min', etc. or null)
  fill_initial_target_nans: true  # Fill target NaNs immediately after loading?

# --- Feature Engineering & Preprocessing Configuration ---
features:
  sequence_length: 72                  # REQUIRED: Lookback window size (e.g., 72 hours = 3 days)
  forecast_horizon: 24                 # REQUIRED: Number of steps ahead to predict (e.g., 24 hours)
  lags: [24, 48, 72, 168]              # List of lag features to create (e.g., 1 day, 2 days, 3 days, 1 week)
  rolling_window_sizes: [24, 72, 168]  # List of window sizes for rolling stats (mean, std)
  use_time_features: true              # Create calendar features (hour, dayofweek, month, etc.)?
  sinus_curve: true                    # Create sinusoidal feature for time of day?
  cosin_curve: true                    # Create cosinusoidal feature for time of day?
  fill_nan: 'ffill'                    # Method to fill NaNs created by lags/rolling windows ('ffill', 'bfill', 0, etc.)
  scaling_method: 'standard'           # Scaling method ('standard', 'minmax', or null/None for no scaling). Fit per fold.

  # Optional: Wavelet Transform configuration
  wavelet_transform:
    apply: false                        # Apply wavelet transform?
    target_or_feature: "target"         # Apply to 'target' before other features, or 'feature' after?
    wavelet_type: "db4"                 # Type of wavelet (e.g., 'db4', 'sym4')
    level: 3                            # Decomposition level (must be > 0)
    use_coeffs: ["approx", "detail_1"]  # Which coefficients to use as features

  # Optional: Feature Clipping configuration
  clipping:
    apply: false   # Apply clipping to generated features (excluding target)?
    clip_min: 0    # Minimum value for clipping
    clip_max: 400  # Maximum value for clipping

# --- Model Architecture Configuration ---
model:
  # input_size: null         # Removed: Calculated automatically based on features and passed directly to model
  hidden_size: 128            # REQUIRED: Number of units in LSTM hidden layers
  num_layers: 2               # REQUIRED: Number of LSTM layers
  dropout: 0.2                # REQUIRED: Dropout rate (between 0.0 and 1.0)
  use_residual_skips: false   # Add residual connection from input to LSTM output?
  # forecast_horizon: null    # Set automatically from features.forecast_horizon

# --- Training Configuration (PyTorch Lightning) ---
training:
  batch_size: 64               # REQUIRED: Batch size for training
  epochs: 50                   # REQUIRED: Max number of training epochs per fold
  learning_rate: 0.001         # REQUIRED: Initial learning rate for Adam optimizer
  loss_function: "MSE"         # Loss function ('MSE' or 'MAE')
  early_stopping_patience: 10  # Optional: Patience for early stopping (epochs). Set null/None to disable. Must be >= 1 if set.
  scheduler_step_size: null    # Optional: Step size for StepLR scheduler (epochs). Set null/None to disable. Must be > 0 if set.
  scheduler_gamma: null        # Optional: Gamma factor for StepLR scheduler. Set null/None to disable. Must be 0 < gamma < 1 if set.
  gradient_clip_val: 1.0       # Optional: Value for gradient clipping. Set null/None to disable. Must be >= 0.0 if set.
  num_workers: 0               # Number of workers for DataLoader (>= 0). 0 means data loading happens in the main process.
  precision: 32                # Training precision (16, 32, 64, 'bf16')

# --- Cross-Validation Configuration (Rolling Window) ---
cross_validation:
  n_splits: 5                # REQUIRED: Number of CV folds (must be > 0)
  test_size_fraction: 0.1    # REQUIRED: Fraction of the *fixed training window size* for the test set (0 < frac < 1)
  val_size_fraction: 0.1     # REQUIRED: Fraction of the *fixed training window size* for the validation set (0 < frac < 1)
  initial_train_size: null   # Optional: Size of the fixed training window (integer samples or float fraction of total data > 0). If null, estimated automatically.

# --- Evaluation Configuration ---
evaluation:
  eval_batch_size: 128       # REQUIRED: Batch size for evaluation/testing (must be > 0)
  save_plots: true           # Save evaluation plots (predictions, residuals)?
  plot_sample_size: 1000     # Optional: Max number of points in time series plots (must be > 0 if set)

# --- Optuna Hyperparameter Optimization Configuration ---
optuna:
  enabled: false             # Enable Optuna HPO? If true, requires optuna.py script.
  n_trials: 20               # Number of trials to run (must be > 0)
  storage: null              # Optional: Optuna storage URL (e.g., "sqlite:///output/hpo_results/study.db"). If null, uses in-memory.
  direction: "minimize"      # Optimization direction ('minimize' or 'maximize')
  metric_to_optimize: "val_mae_orig_scale"  # Metric logged by LightningModule to optimize
  pruning: true              # Enable Optuna trial pruning?
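For reference, a config like the one above is typically loaded and validated before the pipeline runs. The snippet below is a minimal sketch only: it assumes the `MainConfig` class exposed by `forecasting_model.utils` is a pydantic-style model that accepts the parsed YAML as keyword arguments, which this commit does not show.

import yaml
from forecasting_model.utils import MainConfig  # assumed pydantic-style model

with open("config.yaml", "r") as f:
    raw_cfg = yaml.safe_load(f)   # parse the YAML above into nested dicts

cfg = MainConfig(**raw_cfg)       # validate types and required fields
print(cfg.features.sequence_length, cfg.training.batch_size)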
@@ -1,75 +0,0 @@

import argparse
import logging
import sys
from pathlib import Path
import time

# Import necessary components from your project structure
from data_analysis.utils.config_model import load_settings, Settings  # Import loading function and model
from data_analysis.analysis.pipeline import run_eda_pipeline  # Import the pipeline entry point

# Silence overly verbose libraries if needed (e.g., matplotlib)
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)  # Example: set to WARNING or ERROR

# --- Basic Logging Setup ---
# Configure logging early to catch basic issues.
# The level might be adjusted after config loading.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()

# --- Argument Parsing ---
def parse_arguments():
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run the Energy Forecasting EDA pipeline.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='config.yaml',  # Provide a default config file name
        help="Path to the YAML configuration file."
    )
    # Add other potential command-line overrides here if needed later
    # parser.add_argument('--debug', action='store_true', help="Override log level to DEBUG.")

    args = parser.parse_args()
    return args

# --- Main Execution ---
def main():
    """Main execution function."""
    args = parse_arguments()
    config_path = Path(args.config)
    start_time = time.perf_counter()

    # --- Configuration Loading ---
    _ = load_settings(config_path)
    logger.info(f"Using configuration from: {config_path.resolve()} (or defaults if loading failed)")

    # --- Pipeline Execution ---
    try:
        # Call the main function from your pipeline module
        run_eda_pipeline()

        end_time = time.perf_counter()
        logger.info(f"Main script finished successfully in {end_time - start_time:.2f} seconds.")

    except SystemExit as e:
        # Catch SystemExit if pipeline runner exits intentionally
        logger.warning(f"Pipeline exited with code {e.code}.")
        sys.exit(e.code)  # Propagate exit code
    except Exception as e:
        logger.critical(f"A critical error occurred during pipeline execution: {e}", exc_info=True)
        end_time = time.perf_counter()
        logger.info(f"Main script failed after {end_time - start_time:.2f} seconds.")
        sys.exit(1)
    return

if __name__ == "__main__":
    main()
    exit(1)
@@ -5,4 +5,39 @@ This module provides a configurable PyTorch-based LSTM model for time series for

with support for feature engineering, cross-validation, and evaluation.
"""

__version__ = "0.1.0"

# Expose core components for easier import
from .data_processing import (
    load_raw_data,
    engineer_features,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
    TimeSeriesDataset
)
from .model import LSTMForecastLightningModule
from .evaluation import (
    evaluate_fold_predictions,
    # Optionally expose the standalone evaluation utility if needed externally
    # evaluate_model_on_fold_test_set
)

# Expose main configuration class from utils
from .utils import MainConfig

# Expose the main execution script function if it's intended to be callable as a function
# from .forecasting_model import run  # Assuming the main script is named forecasting_model.py

# Define __all__ for explicit public API (optional but good practice)
__all__ = [
    "load_raw_data",
    "engineer_features",
    "TimeSeriesCrossValidationSplitter",
    "prepare_fold_data_and_loaders",
    "TimeSeriesDataset",
    "LSTMForecastLightningModule",
    "evaluate_fold_predictions",
    # "evaluate_model_on_fold_test_set",  # Uncomment if exposed
    "MainConfig",
    # "run",  # Uncomment if exposed
]
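Given the re-exports above, downstream code can import the public API from the package root rather than from the individual modules, for example:

from forecasting_model import (
    MainConfig,
    load_raw_data,
    engineer_features,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
    LSTMForecastLightningModule,
)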
@@ -1,67 +1,751 @@

import logging
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import Tuple, Generator, List, Optional, Union, Dict, Literal, Type

# Use relative import for utils within the package
from .utils.config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
# Optional: Import wavelet library if needed later
# import pywt

logger = logging.getLogger(__name__)

# --- Data Loading ---
def load_raw_data(config: DataConfig) -> pd.DataFrame:
    """
    Load raw time series data from a CSV file, handling specific formats,
    performing initial cleaning, frequency checks, and NaN filling based on config.

    Args:
        config: DataConfig object containing file path, raw/standard column names,
            frequency settings, and NaN handling flags.

    Returns:
        DataFrame with a standardized datetime index (named config.datetime_col)
        and a standardized, cleaned target column (named config.target_col).

    Raises:
        FileNotFoundError: If the data path does not exist.
        ValueError: If specified raw columns are not found, datetime parsing fails,
            or frequency checks indicate critical issues.
        Exception: For other pandas read_csv or processing errors.
    """
    logger.info(f"Loading raw data from: {config.data_path}")
    try:
        # --- Initial Load ---
        df = pd.read_csv(config.data_path, header=0)
        logger.debug(f"Loaded raw data shape: {df.shape}")

        # --- Validate Raw Columns ---
        if config.raw_datetime_col not in df.columns:
            raise ValueError(f"Raw datetime column '{config.raw_datetime_col}' not found in {config.data_path}")
        if config.raw_target_col not in df.columns:
            raise ValueError(f"Raw target column '{config.raw_target_col}' not found in {config.data_path}")

        # --- Time Parsing (Specific Format Handling) ---
        logger.info(f"Parsing raw datetime column: '{config.raw_datetime_col}'")
        try:
            # Extract the start time part 'dd.mm.yyyy hh:mm'
            # Handle potential errors during split if format deviates
            start_times = df[config.raw_datetime_col].astype(str).str.split(' - ', expand=True)[0]
            # Define the specific format
            datetime_format = config.raw_datetime_format or '%d.%m.%Y %H:%M'
            # Parse to datetime, coercing errors to NaT
            parsed_timestamps = pd.to_datetime(start_times, format=datetime_format, errors='coerce')
        except Exception as e:
            logger.error(f"Failed to split or parse raw datetime column '{config.raw_datetime_col}' using expected format: {e}", exc_info=True)
            raise ValueError("Datetime parsing failed. Check raw_datetime_col format and data.")

        # Check for parsing errors (NaT values)
        num_parsing_errors = parsed_timestamps.isnull().sum()
        if num_parsing_errors > 0:
            original_len = len(df)
            df = df.loc[parsed_timestamps.notnull()].copy()  # Keep only rows with valid timestamps
            parsed_timestamps = parsed_timestamps.dropna()
            logger.warning(f"Dropped {num_parsing_errors} rows ({num_parsing_errors/original_len:.1%}) due to timestamp parsing errors "
                           f"(expected format: '{datetime_format}' on start time).")
            if df.empty:
                raise ValueError("No valid timestamps found after parsing. Check data format.")

        # Assign parsed timestamp and set as index with standardized name
        df[config.datetime_col] = parsed_timestamps
        df = df.set_index(config.datetime_col)
        logger.debug(f"Set '{config.datetime_col}' as index.")

        # --- Target Column Processing ---
        logger.info(f"Processing target column: '{config.raw_target_col}' -> '{config.target_col}'")
        # Convert raw target to numeric, coercing errors
        df[config.target_col] = pd.to_numeric(df[config.raw_target_col], errors='coerce')

        # Handle NaNs caused by coercion
        num_coercion_errors = df[config.target_col].isnull().sum()
        if num_coercion_errors > 0:
            logger.warning(f"Found {num_coercion_errors} non-numeric values in raw target column '{config.raw_target_col}'. Coerced to NaN.")
            # Keep rows with NaN for now, handle based on config flag below

        # --- Column Selection ---
        # Keep only the standardized target column for the forecasting pipeline
        # Discard raw columns and any others loaded initially
        df = df[[config.target_col]]
        logger.debug(f"Selected target column '{config.target_col}'. Shape: {df.shape}")

        # --- Initial Target NaN Filling (Optional) ---
        if config.fill_initial_target_nans:
            missing_prices = df[config.target_col].isnull().sum()
            if missing_prices > 0:
                logger.info(f"Found {missing_prices} missing values in target column '{config.target_col}'. Applying ffill then bfill.")
                df[config.target_col] = df[config.target_col].ffill()
                df[config.target_col] = df[config.target_col].bfill()  # Fill remaining NaNs at the start

                final_missing = df[config.target_col].isnull().sum()
                if final_missing > 0:
                    logger.error(f"{final_missing} missing values REMAIN in target column after ffill/bfill. Cannot proceed.")
                    raise ValueError("Target column contains unfillable NaN values.")
            else:
                logger.debug("No missing values found in target column.")
        else:
            logger.info("Skipping initial NaN filling for target column as per config.")
            # Warning if NaNs exist and aren't being filled here
            if df[config.target_col].isnull().any():
                logger.warning(f"NaNs exist in target column '{config.target_col}' and initial filling is disabled.")

        # --- Frequency Check & Setting ---
        logger.info("Checking time index frequency...")
        df = df.sort_index()  # Ensure index is sorted before frequency checks

        # Handle duplicate timestamps before frequency inference
        duplicates = df.index.duplicated().sum()
        if duplicates > 0:
            logger.warning(f"Found {duplicates} duplicate timestamps. Keeping the first occurrence.")
            df = df[~df.index.duplicated(keep='first')]

        if config.expected_frequency:
            inferred_freq = pd.infer_freq(df.index)
            logger.debug(f"Inferred frequency: {inferred_freq}")

            if inferred_freq == config.expected_frequency:
                logger.info(f"Inferred frequency matches expected ('{config.expected_frequency}'). Setting index frequency.")
                df = df.asfreq(config.expected_frequency)
                # Check for NaNs introduced by asfreq (filling gaps)
                missing_after_asfreq = df[config.target_col].isnull().sum()
                if missing_after_asfreq > 0:
                    logger.warning(f"{missing_after_asfreq} NaNs appeared after setting frequency to '{config.expected_frequency}'. Applying ffill/bfill.")
                    # Only fill if initial filling was also enabled, otherwise just warn? Be explicit.
                    if config.fill_initial_target_nans:
                        df[config.target_col] = df[config.target_col].ffill().bfill()
                        if df[config.target_col].isnull().any():
                            logger.error("NaNs still present after attempting to fill gaps from asfreq. Check data continuity.")
                            raise ValueError("Unfillable NaNs after setting frequency.")
                    else:
                        logger.warning("Initial NaN filling was disabled, leaving NaNs introduced by asfreq.")

            elif inferred_freq:
                logger.warning(f"Inferred frequency ('{inferred_freq}') does NOT match expected ('{config.expected_frequency}'). Index frequency will not be explicitly set. This might affect time-based features or models assuming regular intervals.")
                # Consider raising an error depending on strictness needed
                # raise ValueError("Inferred frequency does not match expected frequency.")
            else:
                logger.error(f"Could not infer frequency, but expected frequency was set to '{config.expected_frequency}'. Check data for gaps or irregularities. Index frequency will not be explicitly set.")
                # This is often a critical issue for time series models
                raise ValueError("Could not infer frequency. Ensure data has regular intervals matching expected_frequency.")
        else:
            logger.info("No expected frequency specified in config. Skipping frequency check and setting.")

        logger.info(f"Data loading and initial preparation complete. Final shape: {df.shape}")
        return df

    except FileNotFoundError:
        logger.error(f"Data file not found at: {config.data_path}")
        raise
    except ValueError as e:  # Catch ValueErrors raised internally or by pandas
        logger.error(f"Data processing error: {e}", exc_info=True)
        raise
    except Exception as e:  # Catch other unexpected errors
        logger.error(f"Failed to load or process data from {config.data_path}: {e}", exc_info=True)
        raise

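As an illustration of the parsing step above (not part of the committed file), the raw interval strings are reduced to their start timestamp before conversion; the sample value here is made up but follows the expected 'start - end' shape:

import pandas as pd

raw = pd.Series(["01.01.2023 00:00 - 01.01.2023 01:00"])              # illustrative raw MTU value
start = raw.str.split(" - ", expand=True)[0]                          # keep the interval start
ts = pd.to_datetime(start, format="%d.%m.%Y %H:%M", errors="coerce")  # unparseable rows become NaT
print(ts.iloc[0])  # 2023-01-01 00:00:00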
# --- Feature Engineering ---
def engineer_features(df: pd.DataFrame, target_col: str, feature_config: FeatureConfig) -> pd.DataFrame:
    """
    Create time-series features from the target column and datetime index.
    This function operates on a specific slice of data provided during
    cross-validation setup.

    Args:
        df: DataFrame containing the target column and a datetime index.
            Should contain enough history for lookbacks (lags, rolling windows).
        target_col: The name of the column to engineer features from.
        feature_config: Configuration object specifying which features to create.

    Returns:
        DataFrame with original target and engineered features.

    Raises:
        ValueError: If target_col is not in df or configuration is invalid.
        ImportError: If wavelets are requested but pywt is not installed.
    """
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in DataFrame for feature engineering.")

    logger.info("Starting feature engineering...")
    features_df = df[[target_col]].copy()  # Start with the target

    # 1. Lags
    if feature_config.lags:
        logger.debug(f"Creating lag features for lags: {feature_config.lags}")
        for lag in feature_config.lags:
            if lag <= 0:
                logger.warning(f"Ignoring non-positive lag value: {lag}")
                continue
            features_df[f'{target_col}_lag_{lag}'] = df[target_col].shift(lag)

    # 2. Rolling Window Statistics
    if feature_config.rolling_window_sizes:
        logger.debug(f"Creating rolling window features for sizes: {feature_config.rolling_window_sizes}")
        for window in feature_config.rolling_window_sizes:
            if window <= 0:
                logger.warning(f"Ignoring non-positive rolling window size: {window}")
                continue
            # Shift by 1 so the window does not include the current observation
            # Use closed='left' to ensure window ends *before* the current point
            rolling_obj = df[target_col].shift(1).rolling(window=window, min_periods=window // 2, closed='left')
            features_df[f'{target_col}_rolling_mean_{window}'] = rolling_obj.mean()
            features_df[f'{target_col}_rolling_std_{window}'] = rolling_obj.std()
            # Add more stats if needed (e.g., min, max, median)

    # 3. Time/Calendar Features
    if feature_config.use_time_features:
        logger.debug("Creating time/calendar features.")
        idx = features_df.index  # Use index from features_df
        features_df['hour'] = idx.hour
        features_df['dayofweek'] = idx.dayofweek
        features_df['dayofmonth'] = idx.day
        features_df['dayofyear'] = idx.dayofyear
        # Ensure 'weekofyear' is Int64 to handle potential NAs if index isn't perfectly continuous (though unlikely here)
        features_df['weekofyear'] = idx.isocalendar().week.astype(pd.Int64Dtype())  # pandas >= 1.1.0
        features_df['month'] = idx.month
        features_df['year'] = idx.year
        features_df['quarter'] = idx.quarter
        features_df['is_weekend'] = (idx.dayofweek >= 5).astype(int)

    # 4. Sinusoidal Time Features (Optional, based on config)
    if feature_config.sinus_curve:
        logger.debug("Creating sinusoidal daily time feature.")
        seconds_in_day = 24 * 60 * 60
        seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
        features_df['sin_day'] = np.sin(2 * np.pi * seconds_past_midnight / seconds_in_day)

    if feature_config.cosin_curve:  # Assuming this means cos for day
        logger.debug("Creating cosinusoidal daily time feature.")
        seconds_in_day = 24 * 60 * 60
        seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
        features_df['cos_day'] = np.cos(2 * np.pi * seconds_past_midnight / seconds_in_day)

    # 5. Wavelet Transform (Optional)
    if feature_config.wavelet_transform and feature_config.wavelet_transform.apply:
        logger.warning("Wavelet feature engineering is specified but not implemented yet.")

    # 6. Handling NaNs generated during feature engineering (for *generated* features)
    feature_cols_generated = [col for col in features_df.columns if col != target_col]
    if feature_cols_generated:  # Only fill NaNs if features were actually generated
        nan_handler = feature_config.fill_nan
        if nan_handler is not None:
            fill_value: Optional[Union[str, float]] = None
            fill_method: Optional[str] = None

            if isinstance(nan_handler, str):
                if nan_handler in ['ffill', 'bfill']:
                    fill_method = nan_handler
                    logger.debug(f"Filling NaNs in generated features using method: '{fill_method}'")
                elif nan_handler == 'mean':
                    logger.warning("NaN filling with 'mean' in generated features is applied globally here;"
                                   " consider per-fold mean filling if lookahead is a concern.")
                    # Calculate mean only on the slice provided, potentially leaking info if slice includes val/test
                    # Better to use ffill/bfill here or handle after split
                    fill_value = features_df[feature_cols_generated].mean()  # Calculate mean per feature column
                    logger.debug("Filling NaNs in generated features using column means.")
                else:
                    logger.warning(f"Unsupported string fill_nan method '{nan_handler}' for generated features. Using 'ffill'.")
                    fill_method = 'ffill'
            elif isinstance(nan_handler, (int, float)):
                fill_value = float(nan_handler)
                logger.debug(f"Filling NaNs in generated features with value: {fill_value}")
            else:
                logger.warning(f"Invalid fill_nan type: {type(nan_handler)}. NaNs in features may remain.")

            # Apply filling only to generated feature columns
            if fill_method:
                features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method=fill_method)
                if fill_method == 'ffill':
                    features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(method='bfill')
            elif fill_value is not None:
                # fillna with Series/dict for column-wise mean, or scalar for constant value
                features_df[feature_cols_generated] = features_df[feature_cols_generated].fillna(value=fill_value)
        else:
            logger.warning("`fill_nan` is None. NaNs generated by feature engineering may remain.")

        remaining_nans = features_df[feature_cols_generated].isnull().sum().sum()
        if remaining_nans > 0:
            logger.warning(f"{remaining_nans} NaN values remain in generated features.")

    # 7. Clipping (Optional) - Apply *after* feature generation but *before* scaling
    if feature_config.clipping and feature_config.clipping.apply:  # Check nested config
        clip_config = feature_config.clipping
        logger.debug(f"Clipping features (excluding target '{target_col}') between {clip_config.clip_min} and {clip_config.clip_max}")
        feature_cols_to_clip = [col for col in features_df.columns if col != target_col]
        if not feature_cols_to_clip:
            logger.warning("Clipping enabled, but no feature columns found to clip (only target exists?).")
        else:
            features_df[feature_cols_to_clip] = features_df[feature_cols_to_clip].clip(
                lower=clip_config.clip_min, upper=clip_config.clip_max
            )

    logger.info(f"Feature engineering completed. DataFrame shape: {features_df.shape}")
    logger.debug(f"Feature columns: {features_df.columns.tolist()}")

    return features_df

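A small sketch of the lag/rolling semantics used above (illustrative only): shifting by one step before rolling keeps the current observation out of its own window, which is what prevents target leakage into the features. The module additionally passes closed='left' and min_periods=window // 2, omitted here for brevity.

import pandas as pd

s = pd.Series(range(6), index=pd.date_range("2023-01-01", periods=6, freq="h"))
lag_2 = s.shift(2)                                        # value observed 2 steps earlier
roll_mean_3 = s.shift(1).rolling(window=3, min_periods=1).mean()
print(roll_mean_3.iloc[3])  # mean of s[0..2] = 1.0; excludes the value at the current timestamp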
# --- Cross Validation ---
class TimeSeriesCrossValidationSplitter:
    """
    Generates indices for time series cross-validation using a rolling (sliding) window.

    The training window has a fixed size. For each split, the entire window
    (train, validation, and test sets) slides forward by a specified step size
    (typically the size of the test set). Validation and test set sizes are
    calculated as fractions of the fixed training window size.
    """
    def __init__(self, config: CrossValidationConfig, n_samples: int):
        """
        Args:
            config: CrossValidationConfig with split parameters.
            n_samples: Total number of samples in the dataset.
        """
        self.n_splits = config.n_splits
        self.val_frac = config.val_size_fraction
        self.test_frac = config.test_size_fraction
        self.initial_train_size = config.initial_train_size  # Used as the FIXED train size for rolling window
        self.n_samples = n_samples

        if not (0 < self.val_frac < 1):
            raise ValueError(f"val_size_fraction must be between 0 and 1, got {self.val_frac}")
        if not (0 < self.test_frac < 1):
            raise ValueError(f"test_size_fraction must be between 0 and 1, got {self.test_frac}")
        if self.n_splits <= 0:
            raise ValueError(f"n_splits must be positive, got {self.n_splits}")

        logger.info(f"Initializing TimeSeriesCrossValidationSplitter (Rolling Window): n_splits={self.n_splits}, "
                    f"val_frac={self.val_frac}, test_frac={self.test_frac}, initial_train_size (fixed)={self.initial_train_size}")  # Clarified log

    def _calculate_initial_train_size(self) -> int:
        """Determines the fixed training window size based on config or estimation."""
        # Check if integer is provided
        if isinstance(self.initial_train_size, int) and self.initial_train_size > 0:
            if self.initial_train_size >= self.n_samples:
                raise ValueError(f"initial_train_size ({self.initial_train_size}) must be less than total samples ({self.n_samples})")
            logger.info(f"Using specified fixed training window size: {self.initial_train_size}")
            return self.initial_train_size

        # Check if float/fraction is provided
        elif isinstance(self.initial_train_size, float) and 0 < self.initial_train_size < 1:
            calculated_size = int(self.n_samples * self.initial_train_size)
            if calculated_size <= 0:
                raise ValueError("initial_train_size fraction results in non-positive size.")
            logger.info(f"Using fixed training window size calculated from fraction: {calculated_size}")
            return calculated_size

        # Estimate if None
        elif self.initial_train_size is None:
            min_samples_per_split_step = 2  # Heuristic minimum samples for val+test in one step
            # Estimate val/test based on *potential* train size (crude)
            # Assume train is roughly (1 - val - test) fraction for estimation
            estimated_train_frac = max(0.1, 1.0 - self.val_frac - self.test_frac)  # Ensure non-zero
            estimated_train_n = int(self.n_samples * estimated_train_frac)
            val_test_size_per_step = max(min_samples_per_split_step, int(estimated_train_n * (self.val_frac + self.test_frac)))

            # Tentative initial train size is total minus one val/test block
            fixed_train_n_est = self.n_samples - val_test_size_per_step

            # Basic sanity checks
            if fixed_train_n_est <= 0:
                raise ValueError("Could not estimate a valid initial_train_size (<= 0). Please specify it or check CV fractions.")
            # Need at least 1 sample for train, val, test each theoretically
            est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
            est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
            if fixed_train_n_est + est_val_size + est_test_size > self.n_samples:
                # If the simple estimate is too large, reduce it more drastically
                # Try setting train size = 50% and see if val/test fit?
                fixed_train_n_est = int(self.n_samples * 0.5)
                est_val_size = max(1, int(fixed_train_n_est * self.val_frac))
                est_test_size = max(1, int(fixed_train_n_est * self.test_frac))
                if fixed_train_n_est <= 0 or (fixed_train_n_est + est_val_size + est_test_size > self.n_samples):
                    raise ValueError("Could not estimate a valid initial_train_size. Data too small relative to val/test fractions? Please specify initial_train_size.")

            logger.warning(f"initial_train_size not set, estimated fixed train size for rolling window: {fixed_train_n_est}. "
                           "This is a heuristic; viability depends on n_splits and step size. Validation happens in split().")
            return fixed_train_n_est
        else:
            raise ValueError(f"Invalid initial_train_size: {self.initial_train_size}")

    def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
        """
        Generate train/validation/test indices for each fold using a rolling window.
        Pre-calculates the number of possible splits based on data size and window parameters.

        Yields:
            Tuple of (train_indices, val_indices, test_indices) for each fold.

        Raises:
            ValueError: If parameters lead to invalid split sizes or overlaps,
                or if the data is too small for the configuration.
        """
        indices = np.arange(self.n_samples)
        fixed_train_n = self._calculate_initial_train_size()  # This is now the fixed size

        # Calculate val/test sizes based on the *fixed* training size. Min size of 1.
        val_size = max(1, int(fixed_train_n * self.val_frac))
        test_size = max(1, int(fixed_train_n * self.test_frac))

        # Calculate the total size of one complete train+val+test window
        fold_window_size = fixed_train_n + val_size + test_size

        # Check if even the first window fits
        if fold_window_size > self.n_samples:
            raise ValueError(f"Configuration Error: The total window size (Train {fixed_train_n} + Val {val_size} + Test {test_size} = {fold_window_size}) "
                             f"exceeds total samples ({self.n_samples}). Decrease initial_train_size, fractions, or increase data.")

        # Determine the step size (how much the window slides)
        # Default: slide by the test set size for contiguous, non-overlapping test periods
        step_size = test_size
        if step_size <= 0:
            raise ValueError(f"Step size (derived from test_size {test_size}) must be positive.")

        # --- Calculate the number of splits actually possible ---
        # Last possible start index for the train set
        last_possible_train_start_idx = self.n_samples - fold_window_size
        # Calculate how many steps fit within this range (integer division)
        # If last possible start is 5, step is 2: steps possible at 0, 2, 4 => (5 // 2) + 1 = 2 + 1 = 3
        num_possible_steps = max(0, last_possible_train_start_idx // step_size) + 1  # +1 because we start at index 0

        # Use the minimum of requested splits and possible splits
        actual_n_splits = min(self.n_splits, num_possible_steps)

        if actual_n_splits < self.n_splits:
            logger.warning(f"Data size ({self.n_samples} samples) only allows for {actual_n_splits} splits "
                           f"with fixed train size {fixed_train_n}, val size {val_size}, test size {test_size} (total window {fold_window_size}) and step size {step_size} "
                           f"(requested {self.n_splits}).")
        elif actual_n_splits == 0:
            # This case should be caught by the fold_window_size > self.n_samples check, but belt-and-suspenders
            logger.error("Data too small for even one split with the rolling window configuration.")
            return  # Return generator that yields nothing

        # --- Generate the splits ---
        for i in range(actual_n_splits):
            logger.debug(f"Generating indices for fold {i+1}/{actual_n_splits} (Rolling Window)")  # Log using actual_n_splits

            # Calculate window boundaries for this fold
            train_start_idx = i * step_size
            train_end_idx = train_start_idx + fixed_train_n
            val_start_idx = train_end_idx
            val_end_idx = val_start_idx + val_size
            test_start_idx = val_end_idx
            test_end_idx = test_start_idx + test_size  # = train_start_idx + fold_window_size

            # Determine indices for this fold using slicing
            train_indices = indices[train_start_idx:train_end_idx]
            val_indices = indices[val_start_idx:val_end_idx]
            test_indices = indices[test_start_idx:test_end_idx]

            # --- Basic Validation Checks (Optional, should be guaranteed by calculations) ---
            # Ensure no overlap (guaranteed by slicing if sizes > 0)
            # Ensure sequence (guaranteed by slicing)

            logger.info(f"Fold {i+1}: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
                        f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
                        f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")

            yield train_indices, val_indices, test_indices

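Worked example of the window arithmetic implemented in split(), with illustrative numbers:

n_samples, fixed_train_n = 10_000, 7_000
val_size = test_size = int(fixed_train_n * 0.1)               # 700 each
fold_window = fixed_train_n + val_size + test_size            # 8_400
step = test_size                                              # the window slides by the test size
num_possible = max(0, (n_samples - fold_window) // step) + 1  # (1_600 // 700) + 1 = 3
# Even with n_splits=5 requested, only 3 rolling folds fit in 10_000 samples.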
# --- Dataset Class ---
class TimeSeriesDataset(Dataset):
    """
    PyTorch Dataset for time series forecasting.

    Takes a NumPy array (features + target), sequence length, and forecast horizon,
    and returns (input_sequence, target_sequence) tuples. Compatible with PyTorch
    DataLoaders used by PyTorch Lightning.
    """
    def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: int, target_col_index: int = 0):
        """
        Args:
            data_array: Numpy array of shape (n_samples, n_features).
                Assumes the target variable is one of the columns.
            sequence_length: Length of the input sequence (lookback window).
            forecast_horizon: Number of steps ahead to predict.
            target_col_index: Index of the target column in data_array. Defaults to 0.
        """
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive.")
        if forecast_horizon <= 0:
            raise ValueError("forecast_horizon must be positive.")
        if data_array.ndim != 2:
            raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")
        min_len_required = sequence_length + forecast_horizon
        if min_len_required > data_array.shape[0]:
            raise ValueError(f"sequence_length ({sequence_length}) + forecast_horizon ({forecast_horizon}) = {min_len_required} "
                             f"exceeds total samples provided ({data_array.shape[0]})")
        if not (0 <= target_col_index < data_array.shape[1]):
            raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")

        self.data = torch.tensor(data_array, dtype=torch.float32)
        self.sequence_length = sequence_length
        self.forecast_horizon = forecast_horizon
        self.target_col_index = target_col_index
        self.n_samples = data_array.shape[0]
        self.n_features = data_array.shape[1]

        logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
                     f"seq_len={self.sequence_length}, forecast_horizon={self.forecast_horizon}, "
                     f"target_idx={self.target_col_index}")

    def __len__(self) -> int:
        """Returns the total number of sequences that can be generated."""
        return self.n_samples - self.sequence_length - self.forecast_horizon + 1

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a single (input_sequence, target_sequence) pair.
        """
        if not (0 <= idx < len(self)):
            raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")
        input_start = idx
        input_end = idx + self.sequence_length
        input_sequence = self.data[input_start:input_end, :]
        target_start = input_end
        target_end = target_start + self.forecast_horizon
        target_sequence = self.data[target_start:target_end, self.target_col_index]
        return input_sequence, target_sequence

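The length/indexing contract of TimeSeriesDataset can be checked with a quick sketch (illustrative; assumes the class defined above is importable):

import numpy as np

data = np.random.rand(100, 4).astype("float32")  # 100 samples, 4 features (target in column 0)
ds = TimeSeriesDataset(data, sequence_length=72, forecast_horizon=24, target_col_index=0)
x, y = ds[0]
print(len(ds), x.shape, y.shape)  # 5, torch.Size([72, 4]), torch.Size([24])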
# --- Data Preparation ---
|
# --- Data Preparation ---
|
||||||
def prepare_fold_data_and_loaders(
|
def prepare_fold_data_and_loaders(
|
||||||
full_df: pd.DataFrame,
|
full_df: pd.DataFrame, # Should contain only the target initially
|
||||||
train_idx: np.ndarray,
|
train_idx: np.ndarray,
|
||||||
val_idx: np.ndarray,
|
val_idx: np.ndarray,
|
||||||
test_idx: np.ndarray,
|
test_idx: np.ndarray,
|
||||||
|
target_col: str,
|
||||||
feature_config: FeatureConfig,
|
feature_config: FeatureConfig,
|
||||||
train_config: TrainingConfig,
|
train_config: TrainingConfig,
|
||||||
eval_config: EvaluationConfig
|
eval_config: EvaluationConfig
|
||||||
) -> Tuple[DataLoader, DataLoader, DataLoader, object, int]:
|
) -> Tuple[DataLoader, DataLoader, DataLoader, Union[StandardScaler, MinMaxScaler, None], int]:
|
||||||
"""
|
"""
|
||||||
Prepare data loaders for a single fold.
|
Prepares data loaders for a single cross-validation fold.
|
||||||
|
|
||||||
|
This is essential for time-series CV where scaling must be fitted *only*
|
||||||
|
on the training data of the current fold to prevent lookahead bias.
|
||||||
|
The resulting DataLoaders can be used directly with a PyTorch Lightning Trainer
|
||||||
|
within the cross-validation loop in `main.py`.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Determines the full data range needed for the fold (incl. history for features).
|
||||||
|
2. Engineers features on this slice using `engineer_features`.
|
||||||
|
3. Splits the engineered data into train, validation, test sets based on indices.
|
||||||
|
4. **Fits a scaler ONLY on the training data for the current fold.**
|
||||||
|
5. Transforms train, validation, and test sets using the fitted scaler.
|
||||||
|
6. Creates `TimeSeriesDataset` instances for each set.
|
||||||
|
7. Creates `DataLoader` instances for each set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
full_df: The complete raw DataFrame (datetime index, target column).
|
||||||
|
train_idx: Array of integer indices for the training set.
|
||||||
|
val_idx: Array of integer indices for the validation set.
|
||||||
|
test_idx: Array of integer indices for the test set.
|
||||||
|
target_col: Name of the target column.
|
||||||
|
feature_config: Configuration for feature engineering.
|
||||||
|
train_config: Configuration for training (used for batch size, device hints).
|
||||||
|
eval_config: Configuration for evaluation (used for batch size).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- train_loader: DataLoader for the training set.
|
||||||
|
- val_loader: DataLoader for the validation set.
|
||||||
|
- test_loader: DataLoader for the test set.
|
||||||
|
- target_scaler: The scaler fitted on the target variable (for inverse transform). Can be None.
|
||||||
|
- input_size: The number of features in the input sequences (X).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If indices are invalid, data splitting fails, or NaNs persist.
|
||||||
|
ImportError: If feature engineering requires an uninstalled library.
|
||||||
"""
|
"""
|
||||||
# TODO: Implement data preparation pipeline
|
logger.info(f"Preparing data loaders for fold: train_size={len(train_idx)}, val_size={len(val_idx)}, test_size={len(test_idx)}")
|
||||||
pass
|
if len(train_idx) == 0 or len(val_idx) == 0 or len(test_idx) == 0:
|
||||||
|
raise ValueError("Received empty indices for train, validation, or test set.")
|
||||||
|
|
||||||
|
# 1. Determine data slice needed including history for feature lookback
|
||||||
|
max_lookback = 0
|
||||||
|
if feature_config.lags:
|
||||||
|
max_lookback = max(max_lookback, max(feature_config.lags))
|
||||||
|
if feature_config.rolling_window_sizes:
|
||||||
|
max_lookback = max(max_lookback, max(feature_config.rolling_window_sizes) -1 )
|
||||||
|
max_history_needed = max(max_lookback, feature_config.sequence_length)
|
||||||
|
|
||||||
|
slice_start_idx = max(0, train_idx[0] - max_history_needed)
|
||||||
|
slice_end_idx = test_idx[-1] + 1
|
||||||
|
if slice_start_idx >= slice_end_idx:
|
||||||
|
raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices.")
|
||||||
|
|
||||||
|
fold_data_slice = full_df.iloc[slice_start_idx:slice_end_idx]
|
||||||
|
logger.debug(f"Required data slice for fold: indices {slice_start_idx} to {slice_end_idx-1} "
|
||||||
|
f"(size {len(fold_data_slice)}) for history and fold data.")
|
||||||
|
|
||||||
|
if fold_data_slice.empty:
|
||||||
|
raise ValueError(f"Data slice for fold is empty (indices {slice_start_idx} to {slice_end_idx-1}).")
|
||||||
|
|
||||||
|
# 2. Feature Engineering on the slice
|
||||||
|
try:
|
||||||
|
engineered_df = engineer_features(fold_data_slice.copy(), target_col, feature_config)
|
||||||
|
if engineered_df.empty:
|
||||||
|
raise ValueError("Feature engineering resulted in an empty DataFrame.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Feature engineering failed for fold: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 3. Map absolute indices to iloc positions in the potentially modified engineered_df
|
||||||
|
try:
|
||||||
|
# Use index intersection to find valid locations
|
||||||
|
train_indices_dt = full_df.index[train_idx]
|
||||||
|
val_indices_dt = full_df.index[val_idx]
|
||||||
|
test_indices_dt = full_df.index[test_idx]
|
||||||
|
|
||||||
|
adj_train_idx_loc = engineered_df.index.get_indexer(train_indices_dt.intersection(engineered_df.index))
|
||||||
|
adj_val_idx_loc = engineered_df.index.get_indexer(val_indices_dt.intersection(engineered_df.index))
|
||||||
|
adj_test_idx_loc = engineered_df.index.get_indexer(test_indices_dt.intersection(engineered_df.index))
|
||||||
|
|
||||||
|
# Filter out any -1s just in case (shouldn't happen with intersection)
|
||||||
|
adj_train_idx_loc = adj_train_idx_loc[adj_train_idx_loc != -1]
|
||||||
|
adj_val_idx_loc = adj_val_idx_loc[adj_val_idx_loc != -1]
|
||||||
|
adj_test_idx_loc = adj_test_idx_loc[adj_test_idx_loc != -1]
|
||||||
|
|
||||||
|
|
||||||
|
if len(adj_train_idx_loc) == 0 or len(adj_val_idx_loc) == 0 or len(adj_test_idx_loc) == 0:
|
||||||
|
logger.error(f"Index mapping resulted in empty splits: Train({len(adj_train_idx_loc)}), Val({len(adj_val_idx_loc)}), Test({len(adj_test_idx_loc)})")
|
||||||
|
logger.debug(f"Original counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
|
||||||
|
logger.debug(f"Engineered DF index span: {engineered_df.index.min()} to {engineered_df.index.max()}")
|
||||||
|
raise ValueError("Mapping original indices to engineered DataFrame resulted in empty splits. Check CV indices and NaN handling.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error mapping indices to engineered DataFrame: {e}")
|
||||||
|
raise ValueError("Failed to map CV indices to the feature-engineered data slice.")
|
||||||
|
|
||||||
|
|
||||||
|
# 4. Split engineered data using iloc positions
|
||||||
|
train_df = engineered_df.iloc[adj_train_idx_loc]
|
||||||
|
val_df = engineered_df.iloc[adj_val_idx_loc]
|
||||||
|
test_df = engineered_df.iloc[adj_test_idx_loc]
|
||||||
|
|
||||||
|
logger.debug(f"Fold split shapes after feature engineering: Train={train_df.shape}, Val={val_df.shape}, Test={test_df.shape}")
|
||||||
|
if train_df.empty or val_df.empty or test_df.empty:
|
||||||
|
raise ValueError("One or more data splits (train, val, test) are empty after feature engineering and splitting.")
|
||||||
|
|
||||||
|
# --- Final Check for NaNs before scaling/Dataset creation ---
|
||||||
|
if train_df.isnull().any().any():
|
||||||
|
        nan_cols = train_df.columns[train_df.isnull().any()].tolist()
        logger.error(f"NaNs found in FINAL training data before scaling. Columns: {nan_cols}")
        logger.debug(f"NaN counts per column in train_df:\n{train_df.isnull().sum()[train_df.isnull().any()]}")
        raise ValueError("NaNs present in training data before scaling. Check feature engineering NaN handling.")
    if val_df.isnull().any().any() or test_df.isnull().any().any():
        logger.warning("NaNs found in final validation or test data splits. This might cause issues during evaluation or testing.")

    # 5. Scaling (Fit on Train, Transform All) - CRITICAL PER-FOLD STEP
    feature_cols = train_df.columns.tolist()
    try:
        target_col_index_in_features = feature_cols.index(target_col)
    except ValueError:
        raise ValueError(f"Target column '{target_col}' not found in the final feature columns: {feature_cols}")

    scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
    ScalerClass: Optional[Type[Union[StandardScaler, MinMaxScaler]]] = None

    if feature_config.scaling_method == 'standard':
        ScalerClass = StandardScaler
    elif feature_config.scaling_method == 'minmax':
        ScalerClass = MinMaxScaler
    elif feature_config.scaling_method is None:
        logger.info("No scaling applied for this fold.")
    else:
        raise ValueError(f"Unsupported scaling method: {feature_config.scaling_method}")

    train_data = train_df[feature_cols].values
    val_data = val_df[feature_cols].values
    test_data = test_df[feature_cols].values

    if ScalerClass is not None:
        scaler = ScalerClass()
        target_scaler = ScalerClass()
        logger.info(f"Applying {feature_config.scaling_method} scaling. Fitting on training data for the fold.")
        scaler.fit(train_data)
        target_scaler.fit(train_data[:, target_col_index_in_features].reshape(-1, 1))
        train_data_scaled = scaler.transform(train_data)
        val_data_scaled = scaler.transform(val_data)
        test_data_scaled = scaler.transform(test_data)
        logger.debug("Scaling complete for the fold.")
    else:
        train_data_scaled = train_data
        val_data_scaled = val_data
        test_data_scaled = test_data

    input_size = train_data_scaled.shape[1]

    # 6. Dataset Instantiation
    logger.debug("Creating TimeSeriesDataset instances for the fold.")
    try:
        train_dataset = TimeSeriesDataset(
            train_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon,
            target_col_index=target_col_index_in_features
        )
        val_dataset = TimeSeriesDataset(
            val_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon,
            target_col_index=target_col_index_in_features
        )
        test_dataset = TimeSeriesDataset(
            test_data_scaled, feature_config.sequence_length, feature_config.forecast_horizon,
            target_col_index=target_col_index_in_features
        )
    except ValueError as e:
        logger.error(f"Error creating TimeSeriesDataset: {e}")
        logger.error(f"Shapes fed to Dataset: Train={train_data_scaled.shape}, Val={val_data_scaled.shape}, Test={test_data_scaled.shape}")
        logger.error(f"SeqLen={feature_config.sequence_length}, Horizon={feature_config.forecast_horizon}")
        raise

    # 7. DataLoader Creation
    logger.debug("Creating DataLoaders for the fold.")
    num_workers = getattr(train_config, 'num_workers', 0)
    pin_memory = torch.cuda.is_available()  # Pin memory if CUDA is available

    train_loader = DataLoader(
        train_dataset, batch_size=train_config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False
    )
    val_loader = DataLoader(
        val_dataset, batch_size=eval_config.eval_batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False
    )
    test_loader = DataLoader(
        test_dataset, batch_size=eval_config.eval_batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False
    )

    logger.info("Data loaders prepared successfully for the fold.")

    return train_loader, val_loader, test_loader, target_scaler, input_size
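A hedged sketch (not part of the commit) of how the five values returned above are typically consumed; the variable names mirror the return statement, everything else is illustrative.

# Illustrative only: assumes `train_loader`, `target_scaler` and `input_size` came from the
# fold-preparation function whose tail is shown above.
X_batch, y_batch = next(iter(train_loader))
print(X_batch.shape)   # (batch_size, sequence_length, input_size) - scaled features
print(y_batch.shape)   # (batch_size, forecast_horizon)            - scaled target

# Scaled predictions are later mapped back to the original scale with the fold-specific scaler.
preds_scaled = y_batch.numpy().reshape(-1, 1)   # stand-in for model output
if target_scaler is not None:
    preds_original = target_scaler.inverse_transform(preds_scaled)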
@ -1,82 +1,325 @@
import logging
import os
from pathlib import Path  # Added
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler  # For type hinting target_scaler
from typing import Dict, Any, Optional, Union, List, Tuple

# import matplotlib.pyplot as plt # No longer needed directly
# import seaborn as sns # No longer needed directly

# Assuming config_model and io.plotting are accessible
from forecasting_model.utils.config_model import EvaluationConfig
from forecasting_model.io.plotting import (  # Import the plotting utilities
    setup_plot_style,
    save_plot,
    create_time_series_plot,
    create_scatter_plot,
    create_residuals_plot,
    create_residuals_distribution_plot
)


logger = logging.getLogger(__name__)


# --- Metric Calculations (Utilities - Optional) ---
# (Keep calculate_mae_np, calculate_rmse_np if needed as standalone utils)
# ... (code for calculate_mae_np, calculate_rmse_np unchanged) ...
def calculate_mae_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Mean Absolute Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated MAE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for MAE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate MAE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_absolute_error
        mae = mean_absolute_error(y_true, y_pred)
    except ImportError:
        mae = np.mean(np.abs(y_true - y_pred))
    return float(mae)


def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Root Mean Squared Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated RMSE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for RMSE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate RMSE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_squared_error
        mse = mean_squared_error(y_true, y_pred, squared=True)
    except ImportError:
        mse = np.mean((y_true - y_pred)**2)
    rmse = np.sqrt(mse)
    return float(rmse)


# --- Plotting Functions (Utilities) ---
# REMOVED - These are now imported from io.plotting


# --- Fold Evaluation Function ---

def evaluate_fold_predictions(
    y_true_scaled: np.ndarray,
    y_pred_scaled: np.ndarray,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str,  # Base output directory (e.g., output/cv_results)
    time_index: Optional[np.ndarray] = None  # Optional: Pass time index for x-axis
) -> Dict[str, float]:
    """
    Processes prediction results for a fold's test set using torchmetrics.

    Takes scaled predictions and targets, inverse transforms them,
    calculates final metrics (MAE, RMSE) using torchmetrics.functional,
    and generates evaluation plots using utilities from io.plotting. Assumes
    model inference is already done.

    Args:
        y_true_scaled: Numpy array of scaled ground truth targets (n_samples, horizon).
        y_pred_scaled: Numpy array of scaled model predictions (n_samples, horizon).
        target_scaler: The scaler fitted on the target variable during training. Needed
                       for inverse transforming to original scale. Can be None.
        eval_config: Configuration object for evaluation parameters (e.g., plotting).
        fold_num: The current fold number (e.g., 0, 1, ...).
        output_dir: The base directory to save fold-specific outputs (plots, metrics).
        time_index: Optional array representing the time index for the test set,
                    used for x-axis in time-based plots. If None, uses integer indices.

    Returns:
        Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
        original scale. Metrics will be NaN if inverse transform or calculation fails.

    Raises:
        ValueError: If input shapes are inconsistent or required scaler is missing.
    """
    logger.info(f"Processing evaluation results for Fold {fold_num + 1}...")
    fold_id = fold_num + 1  # Use 1-based indexing for reporting/filenames

    if y_true_scaled.shape != y_pred_scaled.shape:
        raise ValueError(f"Shape mismatch between targets and predictions: "
                         f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
    if y_true_scaled.ndim != 2:
        raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")

    n_samples, horizon = y_true_scaled.shape
    logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")

    # --- Inverse Transform (Outputs NumPy) ---
    y_true_flat_scaled = y_true_scaled.reshape(-1, 1)
    y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)

    y_true_inv_np: np.ndarray
    y_pred_inv_np: np.ndarray

    if target_scaler is not None:
        try:
            logger.debug("Inverse transforming predictions and targets.")
            y_true_inv_np = target_scaler.inverse_transform(y_true_flat_scaled)
            y_pred_inv_np = target_scaler.inverse_transform(y_pred_flat_scaled)
            # Flatten NumPy arrays for metric calculation and plotting
            y_true_np = y_true_inv_np.flatten()
            y_pred_np = y_pred_inv_np.flatten()
        except Exception as e:
            logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
            logger.error("Metrics calculation will be skipped due to inverse transform failure.")
            return {'MAE': np.nan, 'RMSE': np.nan}
    else:
        logger.info("No target scaler provided, assuming inputs are already on original scale.")
        # Flatten NumPy arrays for metric calculation and plotting
        y_true_np = y_true_flat_scaled.flatten()
        y_pred_np = y_pred_flat_scaled.flatten()

    # --- Calculate Metrics using torchmetrics.functional ---
    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}  # Initialize with NaN
    try:
        if len(y_true_np) > 0:  # Check if data exists after potential failures
            y_true_tensor = torch.from_numpy(y_true_np).float().cpu()
            y_pred_tensor = torch.from_numpy(y_pred_np).float().cpu()

            mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
            mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
            rmse_tensor = torch.sqrt(mse_tensor)

            metrics['MAE'] = mae_tensor.item()
            metrics['RMSE'] = rmse_tensor.item()

            logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
        else:
            logger.warning(f"Skipping metric calculation for Fold {fold_id} due to empty data after inverse transform.")

    except Exception as e:
        logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
        # metrics already initialized to NaN

    # --- Generate Plots (Optional - uses plotting utilities) ---
    if eval_config.save_plots and len(y_true_np) > 0:
        logger.info(f"Generating evaluation plots for Fold {fold_id}...")
        # Define plot directory and setup style
        fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
        setup_plot_style()  # Apply consistent styling

        title_suffix = f"Fold {fold_id} Test Set"
        residuals_np = y_true_np - y_pred_np

        # Determine x-axis: use provided time_index if available, else integer indices
        # Note: Flattened y_true/y_pred have length n_samples * horizon
        # Need an appropriate index for this flattened view if time_index is provided.
        # Simple approach: use integer indices for flattened data.
        plot_indices = np.arange(len(y_true_np))
        xlabel = "Time Index (Flattened Horizon x Samples)"
        # If time_index corresponding to the start of each forecast is passed,
        # more sophisticated x-axis handling could be done, but integer indices are simpler.

        try:
            # Create and save each plot using utility functions
            fig_ts = create_time_series_plot(
                plot_indices, y_true_np, y_pred_np,
                f"Predictions vs Actual - {title_suffix}",
                xlabel=xlabel,
                ylabel="Value (Original Scale)",
                max_points=eval_config.plot_sample_size
            )
            save_plot(fig_ts, fold_plot_dir / "predictions_vs_actual.png")

            fig_scatter = create_scatter_plot(
                y_true_np, y_pred_np,
                f"Scatter Plot - {title_suffix}",
                xlabel="Actual Values (Original Scale)",
                ylabel="Predicted Values (Original Scale)"
            )
            save_plot(fig_scatter, fold_plot_dir / "scatter_predictions.png")

            fig_res_time = create_residuals_plot(
                plot_indices, residuals_np,
                f"Residuals Over Time - {title_suffix}",
                xlabel=xlabel,
                ylabel="Residual (Original Scale)",
                max_points=eval_config.plot_sample_size
            )
            save_plot(fig_res_time, fold_plot_dir / "residuals_time.png")

            fig_res_dist = create_residuals_distribution_plot(
                residuals_np,
                f"Residuals Distribution - {title_suffix}",
                xlabel="Residual Value (Original Scale)",
                ylabel="Density"
            )
            save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution.png")

            logger.info(f"Evaluation plots saved to: {fold_plot_dir}")

        except Exception as e:
            logger.error(f"Failed to generate or save one or more plots for Fold {fold_id}: {e}", exc_info=True)
            # Continue without plots, metrics are already calculated.

    elif eval_config.save_plots and len(y_true_np) == 0:
        logger.warning(f"Skipping plot generation for Fold {fold_id} due to empty data.")

    logger.info(f"Evaluation processing finished for Fold {fold_id}.")
    return metrics


# --- (Optional) Wrapper for non-PL usage or direct testing ---
# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str
) -> Dict[str, float]:
    """
    [Optional Function] Evaluates a given model on a fold's test set.

    Runs the inference loop, collects scaled results, then processes them using
    `evaluate_fold_predictions` (which now uses plotting utilities).
    Useful for standalone testing or if not using pl.Trainer.test().
    """
    # ... (Implementation of inference loop remains the same) ...
    logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
    model.eval()
    model.to(device)

    all_preds_scaled_list: List[torch.Tensor] = []
    all_targets_scaled_list: List[torch.Tensor] = []

    with torch.no_grad():
        for i, (X_batch, y_batch) in enumerate(test_loader):
            try:
                X_batch = X_batch.to(device)
                outputs = model(X_batch)  # Scaled outputs

                # Ensure outputs match target shape (e.g., handle trailing dimension)
                if outputs.shape != y_batch.shape:
                    if outputs.ndim == y_batch.ndim + 1 and outputs.shape[-1] == 1:
                        outputs = outputs.squeeze(-1)
                    if outputs.shape != y_batch.shape:
                        raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")

                all_preds_scaled_list.append(outputs.cpu())
                all_targets_scaled_list.append(y_batch.cpu())  # Keep targets on CPU
            except Exception as e:
                logger.error(f"Error during inference batch {i} for Fold {fold_num+1}: {e}", exc_info=True)
                raise ValueError(f"Inference failed on batch {i} for Fold {fold_num+1}")

    # Concatenate results from all batches
    try:
        if not all_preds_scaled_list or not all_targets_scaled_list:
            logger.error(f"No prediction results collected for Fold {fold_num + 1}. Check test_loader.")
            return {'MAE': np.nan, 'RMSE': np.nan}

        y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()
        y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
    except Exception as e:
        logger.error(f"Error concatenating prediction results for Fold {fold_num + 1}: {e}", exc_info=True)
        raise ValueError("Failed to combine batch results during evaluation inference.")

    # Process the collected predictions using the refactored function
    # No time_index passed here by default, plotting will use integer indices
    return evaluate_fold_predictions(
        y_true_scaled=y_true_scaled,
        y_pred_scaled=y_pred_scaled,
        target_scaler=target_scaler,
        eval_config=eval_config,
        fold_num=fold_num,
        output_dir=output_dir,
        time_index=None  # Explicitly pass None
    )
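A small, self-contained sketch of the metric path used above: inverse-transform scaled arrays with the fold's target scaler, then compute MAE/RMSE with torchmetrics.functional. The toy data and scaler are illustrative only.

import numpy as np
import torch
import torchmetrics
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.array([[10.0], [20.0], [30.0]]))   # toy target scaler
y_true_scaled = scaler.transform(np.array([[12.0], [25.0]])).reshape(1, 2)
y_pred_scaled = scaler.transform(np.array([[14.0], [23.0]])).reshape(1, 2)

# Inverse transform to the original scale, then flatten for metric calculation.
y_true = scaler.inverse_transform(y_true_scaled.reshape(-1, 1)).flatten()
y_pred = scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

y_true_t = torch.from_numpy(y_true).float()
y_pred_t = torch.from_numpy(y_pred).float()
mae = torchmetrics.functional.mean_absolute_error(y_pred_t, y_true_t).item()
rmse = torch.sqrt(torchmetrics.functional.mean_squared_error(y_pred_t, y_true_t)).item()
print({'MAE': mae, 'RMSE': rmse})  # both 2.0 for this toy pair of errors [2, 2]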
@ -1,5 +1,26 @@
"""
Input/Output utilities for the forecasting model package.

Currently, primarily includes plotting functions used internally by evaluation.
"""

# Expose plotting utilities if intended for external use
# from .plotting import (
#     setup_plot_style,
#     save_plot,
#     create_time_series_plot,
#     create_scatter_plot,
#     create_residuals_plot,
#     create_residuals_distribution_plot
# )

# __all__ = [
#     "setup_plot_style",
#     "save_plot",
#     "create_time_series_plot",
#     "create_scatter_plot",
#     "create_residuals_plot",
#     "create_residuals_distribution_plot",
# ]

# If nothing is intended for public API from this submodule, leave this file empty
# or with just a docstring.
@ -1,75 +1,307 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
import logging

from pathlib import Path

logger = logging.getLogger(__name__)


def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Set up a consistent plotting style using seaborn if enabled.

    Args:
        use_seaborn: Whether to apply seaborn styling.
    """
    if use_seaborn:
        try:
            sns.set_theme(style="whitegrid", palette="muted")
            plt.rcParams['figure.figsize'] = (12, 6)  # Default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        # Optional: Define a default matplotlib style if seaborn is not used
        plt.style.use('default')
        logger.debug("Using default matplotlib plot style.")


def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save matplotlib figure to a file with directory creation and error handling.

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save the plot to.
                  Can be a string or Path object.

    Raises:
        OSError: If the directory cannot be created.
        Exception: For other file saving errors.
    """
    filepath = Path(filename)
    try:
        # Create the parent directory if it doesn't exist
        filepath.parent.mkdir(parents=True, exist_ok=True)

        fig.savefig(filepath, bbox_inches='tight', dpi=150)  # Save with tight bounding box and decent resolution
        logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
        raise  # Re-raise OSError for directory creation issues
    except Exception as e:
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
        raise  # Re-raise other saving errors
    finally:
        # Close the figure to free up memory, regardless of saving success
        plt.close(fig)


def create_time_series_plot(
    x: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure or raise error? Let's return empty.
        return plt.figure()

    logger.debug(f"Creating time series plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 6))  # Consistent size

    n_points = len(x)
    indices = np.arange(n_points)  # Use internal indices for potential slicing

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = indices[::step]
        plot_x = x[::step]
        plot_y_true = y_true[::step]
        plot_y_pred = y_pred[::step]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        plot_x = x
        plot_y_true = y_true
        plot_y_pred = y_pred
        effective_title = title

    ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
    ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig


def create_scatter_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Actual Values",
    ylabel: str = "Predicted Values"
) -> plt.Figure:
    """
    Create a scatter plot of actual vs predicted values.

    Args:
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
        raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
    if len(y_true) == 0:
        logger.warning("Attempting to create scatter plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating scatter plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 8))  # Square figure common for scatter

    # Determine plot limits, handle potential NaNs
    valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if not np.any(valid_mask):
        logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
        # Return empty figure, plot would be blank
        return fig

    y_true_valid = y_true[valid_mask]
    y_pred_valid = y_pred[valid_mask]

    min_val = min(y_true_valid.min(), y_pred_valid.min())
    max_val = max(y_true_valid.max(), y_pred_valid.max())
    plot_range = max_val - min_val
    if plot_range < 1e-6:  # Handle cases where all points are identical
        plot_range = 1.0  # Avoid zero range

    lim_min = min_val - 0.05 * plot_range
    lim_max = max_val + 0.05 * plot_range

    ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
    ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_aspect('equal', adjustable='box')  # Ensure square scaling
    fig.tight_layout()

    return fig


def create_residuals_plot(
    x: np.ndarray,
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Residual (Actual - Predicted)",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a plot of residuals over time.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == residuals.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create residuals time plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals time plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 5))  # Often wider than tall

    n_points = len(x)
    indices = np.arange(n_points)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = indices[::step]
        plot_x = x[::step]
        plot_residuals = residuals[::step]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        plot_x = x
        plot_residuals = residuals
        effective_title = title

    ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
    ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig


def create_residuals_distribution_plot(
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Residual Value",
    ylabel: str = "Density"
) -> plt.Figure:
    """
    Create a distribution plot (histogram and KDE) of residuals using seaborn.

    Args:
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shape is invalid.
    """
    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")
    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Filter out NaNs before plotting and calculating stats
    residuals_valid = residuals[~np.isnan(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
        return fig  # Return empty figure

    # Use seaborn histplot which combines histogram and KDE
    try:
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Fallback to basic matplotlib histogram if seaborn fails
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)
        ylabel = "Frequency"  # Adjust label if only histogram shown

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')

    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig
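A hedged usage sketch for the plotting utilities above, with synthetic data and an illustrative output path; it only exercises the signatures shown in this hunk.

import numpy as np

setup_plot_style()
x = np.arange(200)
y_true = np.sin(x / 10.0)
y_pred = y_true + np.random.normal(0, 0.1, size=x.shape)

fig = create_time_series_plot(
    x, y_true, y_pred,
    title="Sanity check", xlabel="Step", ylabel="Value", max_points=100
)
save_plot(fig, "output/plots/sanity_check.png")  # directory is created, figure is closed afterwards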
@ -1,18 +1,103 @@
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
from typing import Optional, Dict, Any, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Assuming config_model is in sibling directory utils/
from forecasting_model.utils.config_model import ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)


class LSTMForecastLightningModule(pl.LightningModule):
    """
    PyTorch Lightning Module for LSTM-based time series forecasting.

    Encapsulates the model architecture, training, validation, and test logic.
    Uses torchmetrics for efficient metric calculation.
    """
    def __init__(
        self,
        model_config: ModelConfig,
        train_config: TrainingConfig,
        input_size: int,
        target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
    ):
        super().__init__()

        # --- Validate & Store Configs ---
        # Validate the input_size passed during instantiation
        if input_size <= 0:
            raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")

        # Store the validated input_size directly for use in layer definitions
        self._input_size = input_size  # Use a temporary attribute before hparams are saved

        # Ensure forecast_horizon is set in the config for the output layer
        if not hasattr(model_config, 'forecast_horizon') or model_config.forecast_horizon is None or model_config.forecast_horizon <= 0:
            raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
        self.output_size = model_config.forecast_horizon

        # Store configurations - input_size argument will be saved via save_hyperparameters
        self.model_config = model_config
        self.train_config = train_config
        self.target_scaler = target_scaler  # Store scaler for this fold

        # Use save_hyperparameters() to automatically log configs and allow loading
        # Pass input_size explicitly to be saved in hparams
        # Exclude scaler as it's stateful and fold-specific
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])

        # --- Define Model Layers ---
        # Access input_size via hparams now
        self.lstm = nn.LSTM(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.model_config.hidden_size,
            num_layers=self.hparams.model_config.num_layers,
            batch_first=True,  # Input shape: (batch, seq_len, features)
            dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(self.hparams.model_config.dropout)

        # Output layer maps LSTM hidden state to the forecast horizon
        # We typically take the output of the last time step
        self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)

        # Optional residual connection handling
        self.use_residual_skips = self.hparams.model_config.use_residual_skips
        self.residual_projection = None
        if self.use_residual_skips:
            # If input size doesn't match hidden size, project input
            if self.hparams.input_size != self.hparams.model_config.hidden_size:
                # Use hparams.input_size here
                self.residual_projection = nn.Linear(self.hparams.input_size, self.hparams.model_config.hidden_size)
            logger.info("Residual connections enabled.")
            if self.residual_projection:
                logger.info("Residual projection layer added.")

        # --- Define Loss Function ---
        if self.hparams.train_config.loss_function.upper() == 'MSE':
            self.criterion = nn.MSELoss()
        elif self.hparams.train_config.loss_function.upper() == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Unsupported loss function: {self.hparams.train_config.loss_function}")

        # --- Define Metrics (TorchMetrics) ---
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanAbsoluteError(),
            torchmetrics.MeanSquaredError(squared=False)  # RMSE
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

        self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
@ -24,5 +109,184 @@ class LSTMForecastModel(nn.Module):
        Returns:
            Predictions tensor of shape (batch_size, forecast_horizon)
        """
        # LSTM forward pass
        lstm_out, (hidden, cell) = self.lstm(x)  # Shape: (batch, seq_len, hidden_size)

        # Output from the last time step
        last_time_step_out = lstm_out[:, -1, :]  # Shape: (batch_size, hidden_size)

        # Apply dropout
        last_time_step_out = self.dropout(last_time_step_out)

        # Optional Residual Connection
        if self.use_residual_skips:
            residual = x[:, -1, :]  # Input from the last time step: (batch_size, input_size)
            if self.residual_projection:
                residual = self.residual_projection(residual)  # Project to hidden_size
            last_time_step_out = last_time_step_out + residual

        # Final fully connected layer
        predictions = self.fc(last_time_step_out)  # Shape: (batch_size, output_size/horizon)

        return predictions  # Shape: (batch_size, forecast_horizon)

    def _calculate_loss(self, outputs, targets):
        # Ensure shapes match before loss calculation
        if outputs.shape != targets.shape:
            # Squeeze potential extra dim: (batch, horizon, 1) -> (batch, horizon)
            if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
                outputs = outputs.squeeze(-1)
            if outputs.shape != targets.shape:
                raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
        return self.criterion(outputs, targets)

    def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
        """Helper to inverse transform data using the stored target scaler."""
        if self.target_scaler is None:
            # logger.warning("Cannot inverse transform: target_scaler not available.")
            return None  # Cannot inverse transform

        # Scaler expects 2D input (N, 1)
        # Ensure data is on CPU and is float64 for sklearn scaler typically
        data_cpu = data.detach().cpu().numpy().astype(np.float64)
        original_shape = data_cpu.shape
        if data_cpu.ndim == 1:
            data_flat = data_cpu.reshape(-1, 1)
        elif data_cpu.ndim == 2:  # (batch, horizon)
            data_flat = data_cpu.reshape(-1, 1)
        else:
            logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
            data_flat = data_cpu.reshape(-1, 1)

        try:
            inversed_np = self.target_scaler.inverse_transform(data_flat)
            # Return as tensor on the original device
            inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
            # Reshape back? Or keep flat? Keep flat for direct metric use often.
            return inversed_tensor.flatten()
            # return inversed_tensor.reshape(original_shape) # If original shape needed
        except Exception as e:
            logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
            return None  # Return None if inverse transform fails

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch  # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
        outputs = self(x)  # Scaled outputs: (batch, horizon)
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        metrics = self.train_metrics(outputs, y)  # Update internal state
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)  # Log all metrics in collection

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        outputs = self(x)  # Scaled outputs
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        metrics = self.val_metrics(outputs, y)  # Update internal state
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)

        # Log MAE on ORIGINAL scale if scaler is available (often the primary metric for checkpointing/Optuna)
        if self.target_scaler is not None:
            outputs_inv = self._inverse_transform(outputs)
            y_inv = self._inverse_transform(y)

            if outputs_inv is not None and y_inv is not None:
                # Ensure shapes are compatible (flattened by _inverse_transform)
                if outputs_inv.shape == y_inv.shape:
                    self.val_mae_original_scale.update(outputs_inv, y_inv)
                    self.log('val_mae_orig_scale', self.val_mae_original_scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
                else:
                    logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
            else:
                logger.warning("Could not compute original scale MAE in validation due to inverse transform failure.")

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        # Optional: Keep this method ONLY if you want trainer.test() to log metrics.
        # For getting predictions for evaluation, use predict_step.
        # If evaluate_fold_predictions handles all metrics, this can be simplified or removed.
        # Let's simplify it for now to only log loss if needed.
        try:
            x, y = batch
            outputs = self(x)
            loss = self._calculate_loss(outputs, y)
            # Log scaled test metrics if you still want trainer.test() to report them
            metrics = self.test_metrics(outputs, y)
            self.log('test_loss_step', loss, on_step=True, on_epoch=False)  # Log step loss if needed
            self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True)
            # No return needed if just logging
        except Exception as e:
            logger.error(f"Error occurred in test_step for batch {batch_idx}: {e}", exc_info=True)
            # Optionally log something to indicate failure

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Dict[str, torch.Tensor]:
        """
        Runs inference for prediction and returns scaled predictions and targets.
        'batch' might contain only features depending on the DataLoader setup for predict.
        Let's assume the test_loader yields (x, y) pairs for convenience here.
        """
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            x, y = batch
        else:
            # Assume batch contains only features if not a pair
            x = batch
            y = None  # No targets available during prediction if dataloader only yields features

        outputs = self(x)  # Scaled outputs

        result = {'preds_scaled': outputs.detach().cpu()}
        if y is not None:
            # Include targets if they were part of the batch (e.g., using test_loader for predict)
            result['targets_scaled'] = y.detach().cpu()

        return result

    def configure_optimizers(self) -> Union[optim.Optimizer, Tuple[List[optim.Optimizer], List[Dict[str, Any]]]]:
        """
        Configure the optimizer (Adam) and optional LR scheduler.
        """
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.train_config.learning_rate  # Access lr via hparams
        )
        logger.info(f"Configured Adam optimizer with LR: {self.hparams.train_config.learning_rate}")

        # Optional LR Scheduler configuration
        scheduler_config = None
        if hasattr(self.hparams.train_config, 'scheduler_step_size') and \
           self.hparams.train_config.scheduler_step_size is not None and \
           hasattr(self.hparams.train_config, 'scheduler_gamma') and \
           self.hparams.train_config.scheduler_gamma is not None:

            if self.hparams.train_config.scheduler_step_size > 0 and 0 < self.hparams.train_config.scheduler_gamma < 1:
                logger.info(f"Configuring StepLR scheduler with step_size={self.hparams.train_config.scheduler_step_size} "
                            f"and gamma={self.hparams.train_config.scheduler_gamma}")
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.hparams.train_config.scheduler_step_size,
                    gamma=self.hparams.train_config.scheduler_gamma
                )
                scheduler_config = {
                    'scheduler': scheduler,
                    'interval': 'epoch',  # or 'step'
                    'frequency': 1,
                    'monitor': 'val_loss',  # Optional: Only step if monitor improves (for ReduceLROnPlateau)
                }
            else:
                logger.warning("Scheduler parameters provided but invalid (step_size must be >0, 0<gamma<1). No scheduler configured.")

        # Example for ReduceLROnPlateau (if needed later)
        # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
        # scheduler_config = {'scheduler': scheduler, 'monitor': 'val_loss'}

        if scheduler_config:
            return [optimizer], [scheduler_config]
        else:
            return optimizer
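A hedged sketch of wiring the LightningModule above into a PyTorch Lightning Trainer for one fold; the callback choices, monitored metrics, and the `model_config` / `train_config` / loader variables are illustrative assumptions, not prescribed by this commit.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

model = LSTMForecastLightningModule(
    model_config=model_config,    # assumed to carry forecast_horizon, hidden_size, ...
    train_config=train_config,
    input_size=input_size,        # from the fold's data preparation step
    target_scaler=target_scaler,  # fold-specific scaler enables val_mae_orig_scale
)

callbacks = []
if train_config.early_stopping_patience:
    callbacks.append(EarlyStopping(monitor="val_loss", patience=train_config.early_stopping_patience))
callbacks.append(ModelCheckpoint(monitor="val_mae_orig_scale", mode="min"))

trainer = pl.Trainer(
    max_epochs=train_config.epochs,
    gradient_clip_val=train_config.gradient_clip_val,
    precision=train_config.precision,
    callbacks=callbacks,
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
predictions = trainer.predict(model, dataloaders=test_loader)  # list of dicts from predict_step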
@ -1,50 +0,0 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional, Dict, Any
from ..utils.config_model import TrainingConfig


class Trainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        loss_fn: nn.Module,
        device: torch.device,
        config: TrainingConfig,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        target_scaler: Optional[Any] = None
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.device = device
        self.config = config
        self.scheduler = scheduler
        self.target_scaler = target_scaler

        # TODO: Initialize optimizer (Adam)
        # TODO: Initialize early stopping if configured

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.
        """
        # TODO: Implement training loop for one epoch
        pass

    def evaluate(self, loader: DataLoader) -> Dict[str, float]:
        """
        Evaluate model on given data loader.
        """
        # TODO: Implement evaluation with metrics on original scale
        pass

    def train(self) -> Dict[str, Any]:
        """
        Main training loop with validation and early stopping.
        """
        # TODO: Implement full training loop with validation
        pass
@ -2,4 +2,32 @@
Utility functions and classes for the forecasting model.

This package contains configuration models, helper functions, and other utilities.
"""

# Expose configuration models
from .config_model import (
    MainConfig,
    DataConfig,
    FeatureConfig,
    ModelConfig,
    TrainingConfig,
    CrossValidationConfig,
    EvaluationConfig,
    OptunaConfig,
    WaveletTransformConfig,  # Expose nested configs if they might be used directly
    ClippingConfig
)

# Define __all__ for explicit public API
__all__ = [
    "MainConfig",
    "DataConfig",
    "FeatureConfig",
    "ModelConfig",
    "TrainingConfig",
    "CrossValidationConfig",
    "EvaluationConfig",
    "OptunaConfig",
    "WaveletTransformConfig",
    "ClippingConfig",
]
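A hedged illustration of the re-exported config API above: invalid values fail fast at construction time thanks to the pydantic validators defined in config_model (shown in the next hunk). The concrete values are illustrative.

from forecasting_model.utils import ClippingConfig, FeatureConfig

ok = ClippingConfig(apply=True, clip_min=-3.0, clip_max=3.0)
try:
    ClippingConfig(apply=True, clip_min=5.0, clip_max=1.0)  # triggers check_clip_range
except ValueError as err:
    print(f"rejected: {err}")

# Only sequence_length and forecast_horizon are required; the rest fall back to defaults.
features = FeatureConfig(sequence_length=72, forecast_horizon=24, lags=[24, 168])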
@ -1,62 +1,151 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union, Literal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
# --- Nested Configs ---
|
||||||
|
|
||||||
class WaveletTransformConfig(BaseModel):
|
class WaveletTransformConfig(BaseModel):
|
||||||
|
"""Configuration for optional wavelet transform features."""
|
||||||
apply: bool = False
|
apply: bool = False
|
||||||
target_or_feature: str = "target"
|
target_or_feature: Literal['target', 'feature'] = "target"
|
||||||
wavelet_type: str = "db4"
|
wavelet_type: str = "db4"
|
||||||
level: int = 3
|
level: int = Field(3, gt=0) # Level must be positive
|
||||||
use_coeffs: List[str] = ["approx", "detail_1"]
|
use_coeffs: List[str] = ["approx", "detail_1"]
|
||||||
|
|
||||||
|
class ClippingConfig(BaseModel):
|
||||||
|
"""Configuration for optional feature clipping."""
|
||||||
|
apply: bool = False
|
||||||
|
clip_min: float = -5.0
|
||||||
|
clip_max: float = 5.0
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def check_clip_range(self) -> 'ClippingConfig':
|
||||||
|
if self.apply and self.clip_max <= self.clip_min:
|
||||||
|
raise ValueError(f'clip_max ({self.clip_max}) must be greater than clip_min ({self.clip_min}) when clipping is applied.')
|
||||||
|
return self
|
||||||
|
|
||||||
|
# --- Main Config Sections ---
|
||||||
|
|
||||||
class DataConfig(BaseModel):
|
class DataConfig(BaseModel):
|
||||||
data_path: str
|
"""Configuration related to data loading and initial preparation."""
|
||||||
datetime_col: str
|
data_path: str = Field(..., description="Path to the input CSV data file.")
|
||||||
target_col: str
|
# --- Raw Data Specifics ---
|
||||||
|
raw_datetime_col: str = Field(..., description="Name of the raw datetime column in the CSV (e.g., 'MTU (CET/CEST)')")
|
||||||
|
raw_target_col: str = Field(..., description="Name of the raw target/price column in the CSV (e.g., 'Day-ahead Price [EUR/MWh]')")
|
||||||
|
|
||||||
|
raw_datetime_format: str = '%d.%m.%Y %H:%M' # Example, make it configurable if needed
|
||||||
|
|
||||||
|
# --- Standardized Names & Processing ---
|
||||||
|
datetime_col: str = Field(..., description="Standardized name for the datetime index after processing (e.g., 'Timestamp')")
|
||||||
|
target_col: str = Field(..., description="Standardized name for the target column after processing (e.g., 'Price')")
|
||||||
|
expected_frequency: Optional[str] = Field('h', description="Expected pandas frequency string (e.g., 'h', 'D', '15min'). If null, no frequency check/setting is performed.")
|
||||||
|
fill_initial_target_nans: bool = Field(True, description="Forward/backward fill NaNs in the target column immediately after loading?")
|
||||||
|
|
||||||
class FeatureConfig(BaseModel):
    """Configuration for feature engineering and preprocessing."""
    sequence_length: int = Field(..., gt=0)
    forecast_horizon: int = Field(..., gt=0)
    lags: List[int] = []
    rolling_window_sizes: List[int] = []
    use_time_features: bool = True
    sinus_curve: bool = False  # Added
    cosin_curve: bool = False  # Added
    wavelet_transform: Optional[WaveletTransformConfig] = None
    fill_nan: Optional[Union[str, float, int]] = 'ffill'  # Added (e.g., 'ffill', 0)
    clipping: ClippingConfig = ClippingConfig()  # Default instance
    scaling_method: Optional[Literal['standard', 'minmax']] = 'standard'  # Added literal validation

    @field_validator('lags', 'rolling_window_sizes')
    @classmethod
    def check_positive_list_values(cls, v: List[int]) -> List[int]:
        if any(val <= 0 for val in v):
            raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
        return v
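A minimal sketch (an assumed helper, not the project's actual feature engineering) of how lags, rolling_window_sizes and use_time_features typically translate into columns, with fill_nan handling the NaNs created at the start of the series.

import pandas as pd

def add_basic_features(df: pd.DataFrame, target_col: str, cfg: FeatureConfig) -> pd.DataFrame:
    out = df.copy()  # expects a DatetimeIndex
    for lag in cfg.lags:
        out[f'{target_col}_lag_{lag}'] = out[target_col].shift(lag)
    for w in cfg.rolling_window_sizes:
        roll = out[target_col].rolling(window=w)
        out[f'{target_col}_roll_mean_{w}'] = roll.mean()
        out[f'{target_col}_roll_std_{w}'] = roll.std()
    if cfg.use_time_features:
        out['hour'] = out.index.hour
        out['dayofweek'] = out.index.dayofweek
    if cfg.fill_nan in ('ffill', 'bfill'):
        out = out.ffill() if cfg.fill_nan == 'ffill' else out.bfill()
    elif cfg.fill_nan is not None:
        out = out.fillna(cfg.fill_nan)  # numeric constant, e.g. 0 (only these cases sketched)
    return out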
class ModelConfig(BaseModel):
    """Configuration for the forecasting model architecture."""
    # input_size: Optional[int] = Field(None, gt=0) # Removed: Determined dynamically
    # output_size was likewise removed (was: Optional[int] = None  # Will be calculated)
    hidden_size: int = Field(..., gt=0)
    num_layers: int = Field(..., gt=0)
    dropout: float = Field(..., ge=0.0, le=1.0)
    use_residual_skips: bool = False
    # Add forecast_horizon here to ensure LightningModule gets it directly
    forecast_horizon: Optional[int] = Field(None, gt=0)  # Will be set from FeatureConfig
class TrainingConfig(BaseModel):
    """Configuration for the training process (PyTorch Lightning)."""
    batch_size: int = Field(..., gt=0)
    epochs: int = Field(..., gt=0)  # Max epochs
    learning_rate: float = Field(..., gt=0.0)
    loss_function: Literal['MSE', 'MAE'] = 'MSE'
    # device: str = 'auto' # Handled by PL Trainer accelerator/devices args
    early_stopping_patience: Optional[int] = Field(None, ge=1)  # Patience must be >= 1 if set
    scheduler_step_size: Optional[int] = Field(None, gt=0)
    scheduler_gamma: Optional[float] = Field(None, gt=0.0, lt=1.0)
    gradient_clip_val: Optional[float] = Field(None, ge=0.0)  # Added
    num_workers: int = Field(0, ge=0)  # Added
    precision: Literal[16, 32, 64, 'bf16'] = 32  # Added
class CrossValidationConfig(BaseModel):
    """Configuration for time series cross-validation."""
    n_splits: int = Field(..., gt=0)
    test_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the test set.")
    val_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the validation set.")
    initial_train_size: Optional[Union[int, float]] = Field(None, gt=0.0, description="Size of the fixed training window (absolute number or fraction of total data > 0). If null, estimated automatically.")
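The real TimeSeriesCrossValidationSplitter is not shown in this diff; purely to illustrate what these fields control, a simplified fixed-window rolling split could look like the sketch below (the actual splitter's details will differ).

import numpy as np

def rolling_window_splits(n_samples: int, cfg: CrossValidationConfig):
    """Yield (train_idx, val_idx, test_idx) index arrays for a fixed-size rolling window."""
    if cfg.initial_train_size is None:
        train_size = int(0.5 * n_samples)  # placeholder for the automatic estimate
    elif cfg.initial_train_size <= 1:
        train_size = int(cfg.initial_train_size * n_samples)  # given as a fraction
    else:
        train_size = int(cfg.initial_train_size)  # given as an absolute count
    val_size = int(train_size * cfg.val_size_fraction)
    test_size = int(train_size * cfg.test_size_fraction)
    fold_len = train_size + val_size + test_size
    step = max(1, (n_samples - fold_len) // max(1, cfg.n_splits - 1)) if cfg.n_splits > 1 else 0
    for i in range(cfg.n_splits):
        start = i * step
        if start + fold_len > n_samples:
            break
        yield (np.arange(start, start + train_size),
               np.arange(start + train_size, start + train_size + val_size),
               np.arange(start + train_size + val_size, start + fold_len))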
class EvaluationConfig(BaseModel):
    """Configuration for the final evaluation process."""
    # metrics: List[str] = ['MAE', 'RMSE'] # Defined internally now
    eval_batch_size: int = Field(..., gt=0)
    save_plots: bool = True
    plot_sample_size: Optional[int] = Field(1000, gt=0)  # Max points for plots
class OptunaConfig(BaseModel):
    """Optional configuration for Optuna hyperparameter optimization."""
    enabled: bool = False
    n_trials: int = Field(20, gt=0)
    storage: Optional[str] = None  # e.g., "sqlite:///output/hpo_results/study.db"
    direction: Literal['minimize', 'maximize'] = 'minimize'
    metric_to_optimize: str = 'val_mae_orig_scale'
    pruning: bool = True
# --- Top-Level Configuration Model ---

class MainConfig(BaseModel):
    """Main configuration model nesting all sections."""
    project_name: str = "TimeSeriesForecasting"
    random_seed: Optional[int] = 42  # Added top-level seed

    data: DataConfig
    features: FeatureConfig
    model: ModelConfig  # ModelConfig no longer contains input_size
    training: TrainingConfig
    cross_validation: CrossValidationConfig
    evaluation: EvaluationConfig
    optuna: Optional[OptunaConfig] = OptunaConfig()  # Added optional Optuna config

    @model_validator(mode='after')
    def check_forecast_horizon_consistency(self) -> 'MainConfig':
        # Ensure model config gets forecast horizon from features config if not set
        if self.features and self.model:
            if self.model.forecast_horizon is None:
                # If model config doesn't have it, set it from features config
                self.model.forecast_horizon = self.features.forecast_horizon
            elif self.model.forecast_horizon != self.features.forecast_horizon:
                # If both are set but differ, raise error
                raise ValueError(
                    f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
                    f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
                )
        # After potential setting, ensure model.forecast_horizon is actually set
        if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
            raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
        # Input size check is removed as it's not part of static config anymore
        return self

    class Config:
        # Example configuration for Pydantic itself
        validate_assignment = True  # Re-validate on assignment
        # extra = 'forbid' # Forbid extra fields not defined in schema
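Illustrative usage, not part of the file: loading config.yaml into MainConfig runs all nested validation, and the model_validator above copies forecast_horizon from the features section when it is not set on the model.

import yaml

with open('config.yaml', 'r') as f:
    cfg = MainConfig(**yaml.safe_load(f))

assert cfg.model.forecast_horizon == cfg.features.forecast_horizon  # filled in by the validator
cfg.random_seed = 7  # re-validated on assignment thanks to validate_assignment = True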
forecasting_model_run.py (new file, 468 lines)
@@ -0,0 +1,468 @@
import argparse
import logging
import sys
import os
import random
from pathlib import Path
import time
import json

import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger

# Import necessary components from your project structure
# Assuming forecasting_model is a package installable or in PYTHONPATH
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from typing import Dict, List, Any, Optional
||||||
|
|
||||||
|
# Silence overly verbose libraries if needed
|
||||||
|
mpl_logger = logging.getLogger('matplotlib')
|
||||||
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
|
pil_logger = logging.getLogger('PIL')
|
||||||
|
pil_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# --- Basic Logging Setup ---
|
||||||
|
# Configure logging early. Level might be adjusted by config.
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)-7s - %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
# Get the root logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
# --- Argument Parsing ---
|
||||||
|
def parse_arguments():
|
||||||
|
"""Parses command-line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run the Time Series Forecasting training pipeline.",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--config',
|
||||||
|
type=str,
|
||||||
|
default='config.yaml',
|
||||||
|
help="Path to the YAML configuration file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--seed',
|
||||||
|
type=int,
|
||||||
|
default=None, # Default to None, use config value if not provided
|
||||||
|
help="Override random seed defined in config."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--debug',
|
||||||
|
action='store_true',
|
||||||
|
help="Override log level to DEBUG."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--output-dir',
|
||||||
|
type=str,
|
||||||
|
default='output/cv_results', # Default output base directory
|
||||||
|
help="Base directory for saving cross-validation results (checkpoints, logs, plots)."
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
# --- Helper Functions ---
|
||||||
|
|
||||||
|
def load_config(config_path: Path) -> MainConfig:
|
||||||
|
"""
|
||||||
|
Load and validate configuration from YAML file using Pydantic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the YAML configuration file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated MainConfig object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the config file doesn't exist.
|
||||||
|
yaml.YAMLError: If the file is not valid YAML.
|
||||||
|
pydantic.ValidationError: If the config doesn't match the schema.
|
||||||
|
"""
|
||||||
|
if not config_path.is_file():
|
||||||
|
logger.error(f"Configuration file not found at: {config_path}")
|
||||||
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
|
logger.info(f"Loading configuration from: {config_path}")
|
||||||
|
try:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config_dict = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Validate configuration using Pydantic model
|
||||||
|
config = MainConfig(**config_dict)
|
||||||
|
logger.info("Configuration loaded and validated successfully.")
|
||||||
|
return config
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
except Exception as e: # Catches Pydantic validation errors too
|
||||||
|
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def set_seeds(seed: Optional[int] = 42) -> None:
|
||||||
|
"""
|
||||||
|
Set random seeds for reproducibility across libraries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed value to use. If None, uses default 42.
|
||||||
|
"""
|
||||||
|
if seed is None:
|
||||||
|
seed = 42
|
||||||
|
logger.warning(f"No seed provided, using default seed: {seed}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Setting random seed: {seed}")
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
# Ensure reproducibility for CUDA operations where possible
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed) # For multi-GPU
|
||||||
|
# These settings can slow down training but improve reproducibility
|
||||||
|
# torch.backends.cudnn.deterministic = True
|
||||||
|
# torch.backends.cudnn.benchmark = False
|
||||||
|
# PyTorch Lightning seeding (optional, as we seed torch directly)
|
||||||
|
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
|
||||||
|
|
||||||
|
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
|
||||||
|
"""
|
||||||
|
Calculate mean and standard deviation of metrics across folds.
|
||||||
|
Handles potential NaN values by ignoring them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_fold_metrics: A list where each element is a dictionary of
|
||||||
|
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary where keys are metric names and values are dicts
|
||||||
|
containing 'mean' and 'std' for that metric across folds.
|
||||||
|
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
|
||||||
|
"""
|
||||||
|
if not all_fold_metrics:
|
||||||
|
logger.warning("Received empty list for metric aggregation.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
aggregated: Dict[str, Dict[str, float]] = {}
|
||||||
|
# Get metric names from the first valid fold's results
|
||||||
|
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
|
||||||
|
if not first_valid_metrics:
|
||||||
|
logger.warning("No valid fold metrics found for aggregation.")
|
||||||
|
return {}
|
||||||
|
metric_names = list(first_valid_metrics.keys())
|
||||||
|
|
||||||
|
for metric in metric_names:
|
||||||
|
# Collect values for this metric across all folds, ignoring NaNs
|
||||||
|
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
|
||||||
|
valid_values = [v for v in values if v is not None and not np.isnan(v)]
|
||||||
|
|
||||||
|
if not valid_values:
|
||||||
|
logger.warning(f"No valid values found for metric '{metric}' across folds.")
|
||||||
|
mean_val = np.nan
|
||||||
|
std_val = np.nan
|
||||||
|
else:
|
||||||
|
mean_val = float(np.mean(valid_values))
|
||||||
|
std_val = float(np.std(valid_values))
|
||||||
|
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
|
||||||
|
|
||||||
|
aggregated[metric] = {'mean': mean_val, 'std': std_val}
|
||||||
|
|
||||||
|
return aggregated
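For illustration, the aggregation contract on a small input where the second fold has no finite RMSE:

folds = [{'MAE': 10.0, 'RMSE': 15.0}, {'MAE': 12.0, 'RMSE': float('nan')}]
aggregate_cv_metrics(folds)
# -> {'MAE': {'mean': 11.0, 'std': 1.0}, 'RMSE': {'mean': 15.0, 'std': 0.0}}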
|
||||||
|
|
||||||
|
def save_results(results: Dict, filename: Path):
|
||||||
|
"""Save dictionary results to a JSON file."""
|
||||||
|
try:
|
||||||
|
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
logger.info(f"Saved results to {filename}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Training & Evaluation Function ---
|
||||||
|
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
|
||||||
|
"""Runs the full cross-validation training and evaluation pipeline."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# --- Data Loading ---
|
||||||
|
try:
|
||||||
|
df = load_raw_data(config.data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
|
||||||
|
sys.exit(1) # Cannot proceed without data
|
||||||
|
|
||||||
|
# --- Cross-Validation Setup ---
|
||||||
|
try:
|
||||||
|
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
all_fold_test_metrics: List[Dict[str, float]] = []
|
||||||
|
all_fold_best_val_scores: Dict[int, Optional[float]] = {} # Store best val score per fold
|
||||||
|
|
||||||
|
# --- Cross-Validation Loop ---
|
||||||
|
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
|
||||||
|
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
|
||||||
|
fold_start_time = time.perf_counter()
|
||||||
|
fold_id = fold_num + 1
|
||||||
|
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
|
||||||
|
|
||||||
|
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
|
||||||
|
fold_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.debug(f"Fold output directory: {fold_output_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# --- Per-Fold Data Preparation ---
|
||||||
|
logger.info("Preparing data loaders for the fold...")
|
||||||
|
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
|
||||||
|
full_df=df,
|
||||||
|
train_idx=train_idx,
|
||||||
|
val_idx=val_idx,
|
||||||
|
test_idx=test_idx,
|
||||||
|
target_col=config.data.target_col, # Pass target col name explicitly
|
||||||
|
feature_config=config.features,
|
||||||
|
train_config=config.training,
|
||||||
|
eval_config=config.evaluation
|
||||||
|
)
|
||||||
|
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
|
||||||
|
|
||||||
|
# --- Model Initialization ---
|
||||||
|
# Pass input_size directly, ModelConfig no longer holds it.
|
||||||
|
# Ensure forecast horizon is consistent (checked in MainConfig validation)
|
||||||
|
current_model_config = config.model # Use the validated model config
|
||||||
|
|
||||||
|
model = LSTMForecastLightningModule(
|
||||||
|
model_config=current_model_config, # Does not contain input_size
|
||||||
|
train_config=config.training,
|
||||||
|
input_size=input_size, # Pass the dynamically determined input_size
|
||||||
|
target_scaler=target_scaler # Pass the fold-specific scaler
|
||||||
|
)
|
||||||
|
logger.info("LSTMForecastLightningModule initialized.")
|
||||||
|
|
||||||
|
# --- PyTorch Lightning Callbacks ---
|
||||||
|
# Monitor the validation MAE on the original scale (logged by LightningModule)
|
||||||
|
monitor_metric = "val_mae_orig_scale"
|
||||||
|
monitor_mode = "min"
|
||||||
|
|
||||||
|
early_stop_callback = None
|
||||||
|
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
|
||||||
|
early_stop_callback = EarlyStopping(
|
||||||
|
monitor=monitor_metric,
|
||||||
|
min_delta=0.0001, # Minimum change to qualify as improvement
|
||||||
|
patience=config.training.early_stopping_patience,
|
||||||
|
verbose=True,
|
||||||
|
mode=monitor_mode
|
||||||
|
)
|
||||||
|
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
|
||||||
|
|
||||||
|
# Checkpoint callback to save the best model based on validation metric
|
||||||
|
checkpoint_callback = ModelCheckpoint(
|
||||||
|
dirpath=fold_output_dir / "checkpoints",
|
||||||
|
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
|
||||||
|
save_top_k=1,
|
||||||
|
monitor=monitor_metric,
|
||||||
|
mode=monitor_mode,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
|
||||||
|
|
||||||
|
# Learning rate monitor callback
|
||||||
|
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||||
|
|
||||||
|
callbacks = [checkpoint_callback, lr_monitor]
|
||||||
|
if early_stop_callback:
|
||||||
|
callbacks.append(early_stop_callback)
|
||||||
|
|
||||||
|
# --- PyTorch Lightning Logger ---
|
||||||
|
# Log metrics to a CSV file within the fold directory
|
||||||
|
pl_logger = CSVLogger(save_dir=str(output_base_dir), name=f"fold_{fold_id:02d}", version='logs')
|
||||||
|
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
|
||||||
|
|
||||||
|
# --- PyTorch Lightning Trainer ---
|
||||||
|
# Determine accelerator and devices based on PyTorch check
|
||||||
|
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||||
|
devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs [0], [1] etc.
|
||||||
|
precision = getattr(config.training, 'precision', 32) # Default to 32-bit
|
||||||
|
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator=accelerator,
|
||||||
|
devices=devices,
|
||||||
|
max_epochs=config.training.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
logger=pl_logger,
|
||||||
|
log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
|
||||||
|
enable_progress_bar=True, # Set to False for less verbose runs (e.g., HPO)
|
||||||
|
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
|
||||||
|
precision=precision,
|
||||||
|
# deterministic=True, # For stricter reproducibility (can slow down)
|
||||||
|
)
|
||||||
|
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
|
||||||
|
|
||||||
|
# --- Training ---
|
||||||
|
logger.info(f"Starting training for Fold {fold_id}...")
|
||||||
|
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
||||||
|
logger.info(f"Training finished for Fold {fold_id}.")
|
||||||
|
|
||||||
|
# Store best validation score for this fold
|
||||||
|
best_val_score = trainer.checkpoint_callback.best_model_score
|
||||||
|
best_model_path = trainer.checkpoint_callback.best_model_path
|
||||||
|
all_fold_best_val_scores[fold_id] = best_val_score.item() if best_val_score else None
|
||||||
|
if best_val_score is not None:
|
||||||
|
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
|
||||||
|
logger.info(f"Best model checkpoint path: {best_model_path}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
|
||||||
|
best_model_path = None # Ensure evaluation doesn't try to load 'best' if checkpointing failed
|
||||||
|
|
||||||
|
# --- Prediction on Test Set ---
|
||||||
|
# Use trainer.predict() to get model outputs
|
||||||
|
logger.info(f"Starting prediction for Fold {fold_id} using best checkpoint...")
|
||||||
|
# predict_step returns dict {'preds_scaled': ..., 'targets_scaled': ...}
|
||||||
|
# We pass the test_loader here, which yields (x, y) pairs, so predict_step will include targets
|
||||||
|
prediction_results_list = trainer.predict(
|
||||||
|
# model=model, # Not needed if using ckpt_path
|
||||||
|
ckpt_path=best_model_path if best_model_path else 'last', # Load best model or last if best failed
|
||||||
|
dataloaders=test_loader
|
||||||
|
# return_predictions=True # Default is True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if prediction returned results
|
||||||
|
if not prediction_results_list:
|
||||||
|
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
|
||||||
|
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# Concatenate predictions and targets from predict_step results
|
||||||
|
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
|
||||||
|
# Check if targets were included (they should be if using test_loader)
|
||||||
|
if 'targets_scaled' in prediction_results_list[0]:
|
||||||
|
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
|
||||||
|
else:
|
||||||
|
# This case shouldn't happen if using test_loader, but good safeguard
|
||||||
|
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
|
||||||
|
raise ValueError("Targets missing from prediction results.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Final Evaluation & Plotting ---
|
||||||
|
logger.info(f"Processing prediction results for Fold {fold_id}...")
|
||||||
|
fold_metrics = evaluate_fold_predictions(
|
||||||
|
y_true_scaled=all_targets_scaled,
|
||||||
|
y_pred_scaled=all_preds_scaled,
|
||||||
|
target_scaler=target_scaler, # Use the scaler from this fold
|
||||||
|
eval_config=config.evaluation,
|
||||||
|
fold_num=fold_num, # Pass zero-based index
|
||||||
|
output_dir=output_base_dir, # Base dir for saving plots etc.
|
||||||
|
# time_index=df.iloc[test_idx].index # Pass time index if needed
|
||||||
|
)
|
||||||
|
# Save fold metrics
|
||||||
|
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
|
||||||
|
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
|
||||||
|
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
|
||||||
|
|
||||||
|
all_fold_test_metrics.append(fold_metrics)
|
||||||
|
|
||||||
|
# --- (Optional) Log final test metrics using trainer.test() ---
|
||||||
|
# If you want the metrics logged by test_step aggregated, call test now.
|
||||||
|
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
|
||||||
|
# try:
|
||||||
|
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Catch errors during the fold processing (data prep, training, prediction, eval)
|
||||||
|
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
|
||||||
|
all_fold_test_metrics.append({'MAE': np.nan, 'RMSE': np.nan})
|
||||||
|
|
||||||
|
|
||||||
|
# --- Cleanup per fold ---
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.debug("Cleared CUDA cache.")
|
||||||
|
|
||||||
|
fold_end_time = time.perf_counter()
|
||||||
|
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Aggregation and Final Reporting ---
|
||||||
|
logger.info("Cross-validation finished. Aggregating results...")
|
||||||
|
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
|
||||||
|
|
||||||
|
# Save aggregated results
|
||||||
|
final_results = {
|
||||||
|
'aggregated_test_metrics': aggregated_metrics,
|
||||||
|
'per_fold_test_metrics': all_fold_test_metrics,
|
||||||
|
'per_fold_best_val_scores': all_fold_best_val_scores,
|
||||||
|
}
|
||||||
|
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
|
||||||
|
|
||||||
|
|
||||||
|
# Log final results
|
||||||
|
logger.info("--- Aggregated Cross-Validation Test Results ---")
|
||||||
|
if aggregated_metrics:
|
||||||
|
for metric, stats in aggregated_metrics.items():
|
||||||
|
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
|
||||||
|
else:
|
||||||
|
logger.warning("No metrics available for aggregation.")
|
||||||
|
logger.info("-------------------------------------------------")
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
logger.info(f"Training pipeline finished successfully in {end_time - start_time:.2f} seconds.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution ---
|
||||||
|
def run():
|
||||||
|
"""Main execution function."""
|
||||||
|
args = parse_arguments()
|
||||||
|
config_path = Path(args.config)
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
|
||||||
|
# Adjust log level if debug flag is set
|
||||||
|
if args.debug:
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.debug("# --- Debug mode enabled. --- #")
|
||||||
|
|
||||||
|
# --- Configuration Loading ---
|
||||||
|
try:
|
||||||
|
config = load_config(config_path)
|
||||||
|
except Exception:
|
||||||
|
# Error already logged in load_config
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# --- Seed Setting ---
|
||||||
|
# Use command-line seed if provided, otherwise use config seed
|
||||||
|
seed = args.seed if args.seed is not None else getattr(config, 'random_seed', 42)
|
||||||
|
set_seeds(seed)
|
||||||
|
|
||||||
|
# --- Pipeline Execution ---
|
||||||
|
try:
|
||||||
|
run_training_pipeline(config, output_dir)
|
||||||
|
|
||||||
|
except SystemExit as e:
|
||||||
|
logger.warning(f"Pipeline exited with code {e.code}.")
|
||||||
|
sys.exit(e.code) # Propagate exit code
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(f"An critical error occurred during pipeline execution: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
main.py (3 lines removed)
@@ -10,9 +10,6 @@ from forecasting_model.data_processing import (
     TimeSeriesCrossValidationSplitter,
     prepare_fold_data_and_loaders
 )
-from forecasting_model.model import LSTMForecastModel
-from forecasting_model.trainer import Trainer
-from forecasting_model.evaluation import evaluate_fold

 # Configure logging
 logging.basicConfig(
optuna_run.py (new file, 395 lines)
@@ -0,0 +1,395 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import copy # For deep copying config
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import optuna
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
|
||||||
|
# Import the Optuna callback for pruning
|
||||||
|
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
|
||||||
|
|
||||||
|
# Import necessary components from the forecasting_model package
|
||||||
|
from forecasting_model.utils.config_model import MainConfig
|
||||||
|
from forecasting_model.data_processing import (
|
||||||
|
load_raw_data,
|
||||||
|
TimeSeriesCrossValidationSplitter,
|
||||||
|
prepare_fold_data_and_loaders
|
||||||
|
)
|
||||||
|
from forecasting_model.model import LSTMForecastLightningModule
|
||||||
|
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
|
||||||
|
# from forecasting_model.evaluation import ...
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
# Import helper functions from forecasting_model.py (or move them to a shared utils file)
|
||||||
|
# For now, let's redefine simplified versions or assume they exist in utils
|
||||||
|
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
|
||||||
|
|
||||||
|
# Silence overly verbose libraries if needed
|
||||||
|
mpl_logger = logging.getLogger('matplotlib')
|
||||||
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
|
pil_logger = logging.getLogger('PIL')
|
||||||
|
pil_logger.setLevel(logging.WARNING)
|
||||||
|
pl_logger = logging.getLogger('pytorch_lightning')
|
||||||
|
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
|
||||||
|
|
||||||
|
# --- Basic Logging Setup ---
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S')
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
logger = logging.getLogger(__name__) # Logger for this script
|
||||||
|
optuna_lg = logging.getLogger('optuna') # Optuna's logger
|
||||||
|
|
||||||
|
|
||||||
|
# --- Argument Parsing ---
|
||||||
|
def parse_arguments():
|
||||||
|
"""Parses command-line arguments for Optuna HPO."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--config',
|
||||||
|
type=str,
|
||||||
|
default='config.yaml',
|
||||||
|
help="Path to the BASE YAML configuration file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--output-dir',
|
||||||
|
type=str,
|
||||||
|
default='output/hpo_results',
|
||||||
|
help="Directory for saving Optuna study database and potentially best trial info."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--study-name',
|
||||||
|
type=str,
|
||||||
|
default='lstm_forecasting_hpo',
|
||||||
|
help="Name for the Optuna study."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--n-trials',
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Number of Optuna trials to run."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--storage-db',
|
||||||
|
type=str,
|
||||||
|
default=None, # Default to in-memory if not specified
|
||||||
|
help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--metric-to-optimize',
|
||||||
|
type=str,
|
||||||
|
default='val_mae_orig_scale',
|
||||||
|
help="Metric logged during validation to optimize (must match metric name in LightningModule)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--direction',
|
||||||
|
type=str,
|
||||||
|
default='minimize',
|
||||||
|
choices=['minimize', 'maximize'],
|
||||||
|
help="Direction for Optuna optimization."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--pruning',
|
||||||
|
action='store_true',
|
||||||
|
help="Enable Optuna's trial pruning based on intermediate validation results."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--seed',
|
||||||
|
type=int,
|
||||||
|
default=42, # Fixed seed for the HPO process itself
|
||||||
|
help="Random seed for the main HPO script."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--debug',
|
||||||
|
action='store_true',
|
||||||
|
help="Override log level to DEBUG."
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
# --- Optuna Objective Function ---
|
||||||
|
def objective(
|
||||||
|
trial: optuna.Trial,
|
||||||
|
base_config: MainConfig, # Pass the loaded base config
|
||||||
|
df: pd.DataFrame, # Pass the loaded data
|
||||||
|
output_base_dir: Path, # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
|
||||||
|
metric_to_optimize: str,
|
||||||
|
enable_pruning: bool
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Optuna objective function. Trains and evaluates one set of hyperparameters
|
||||||
|
using cross-validation and returns the average validation metric.
|
||||||
|
"""
|
||||||
|
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
|
||||||
|
trial_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# --- 1. Suggest Hyperparameters ---
|
||||||
|
# Make a deep copy of the base config to modify for this trial
|
||||||
|
# Using dict conversion and back might be easier than Pydantic's copy for deep nested updates
|
||||||
|
try:
|
||||||
|
trial_config_dict = copy.deepcopy(base_config.dict()) # Convert to dict for easier modification
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to deep copy base configuration: {e}")
|
||||||
|
raise # Cannot proceed without config
|
||||||
|
|
||||||
|
# Suggest values for hyperparameters we want to tune
|
||||||
|
# Example suggestions (adjust ranges and types as needed):
|
||||||
|
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
|
||||||
|
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
|
||||||
|
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
|
||||||
|
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
|
||||||
|
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
|
||||||
|
# Example: Suggest sequence length? (Requires careful handling as it affects data prep)
|
||||||
|
# trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)
|
||||||
|
|
||||||
|
# --- 2. Re-validate Trial Config (Optional but Recommended) ---
|
||||||
|
try:
|
||||||
|
trial_config = MainConfig(**trial_config_dict)
|
||||||
|
logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
|
||||||
|
# Return a high value (for minimization) to penalize invalid configs
|
||||||
|
return float('inf')
|
||||||
|
|
||||||
|
|
||||||
|
# --- 3. Run Cross-Validation for this Trial ---
|
||||||
|
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
|
||||||
|
fold_best_val_metrics: List[Optional[float]] = []
|
||||||
|
|
||||||
|
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
|
||||||
|
fold_id = fold_num + 1
|
||||||
|
logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
|
||||||
|
fold_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
|
||||||
|
# fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
|
||||||
|
# fold_trial_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# --- Per-Fold Data Prep ---
|
||||||
|
# Use trial_config for batch sizes etc.
|
||||||
|
train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
|
||||||
|
full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, # Test loader not needed here
|
||||||
|
target_col=trial_config.data.target_col,
|
||||||
|
feature_config=trial_config.features,
|
||||||
|
train_config=trial_config.training,
|
||||||
|
eval_config=trial_config.evaluation # Pass eval for batch size if needed by prep?
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Model Instantiation ---
|
||||||
|
current_model_config = trial_config.model.copy(update={'input_size': input_size,
|
||||||
|
'forecast_horizon': trial_config.features.forecast_horizon})
|
||||||
|
model = LSTMForecastLightningModule(
|
||||||
|
model_config=current_model_config,
|
||||||
|
train_config=trial_config.training,
|
||||||
|
target_scaler=target_scaler
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Callbacks for this Trial/Fold ---
|
||||||
|
# Monitor the metric Optuna cares about
|
||||||
|
monitor_mode = "min" if args.direction == "minimize" else "max"
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
|
||||||
|
early_stopping = EarlyStopping(
|
||||||
|
monitor=metric_to_optimize,
|
||||||
|
patience=trial_config.training.early_stopping_patience,
|
||||||
|
mode=monitor_mode,
|
||||||
|
verbose=False # Less verbose during HPO
|
||||||
|
)
|
||||||
|
callbacks.append(early_stopping)
|
||||||
|
|
||||||
|
# Add Optuna Pruning Callback
|
||||||
|
if enable_pruning:
|
||||||
|
pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
|
||||||
|
callbacks.append(pruning_callback)
|
||||||
|
|
||||||
|
# Optional: LR Monitor
|
||||||
|
# callbacks.append(LearningRateMonitor(logging_interval='epoch'))
|
||||||
|
|
||||||
|
# --- Trainer for this Trial/Fold ---
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
|
||||||
|
devices=1 if torch.cuda.is_available() else None,
|
||||||
|
max_epochs=trial_config.training.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
logger=False, # Disable default PL logging during HPO
|
||||||
|
enable_checkpointing=False, # Disable checkpoint saving during HPO
|
||||||
|
enable_progress_bar=False, # Disable progress bar for cleaner logs
|
||||||
|
enable_model_summary=False, # Disable model summary
|
||||||
|
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
|
||||||
|
precision=getattr(trial_config.training, 'precision', 32),
|
||||||
|
# Log GPU usage if available?
|
||||||
|
# log_gpu_memory='min_max',
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Fit the Model ---
|
||||||
|
logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
|
||||||
|
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
||||||
|
|
||||||
|
# --- Get Best Validation Score for Pruning/Reporting ---
|
||||||
|
# Access the monitored metric value from the trainer's logged metrics or callback state
|
||||||
|
# Ensure the key matches exactly what's logged in validation_step
|
||||||
|
best_val_score = trainer.callback_metrics.get(metric_to_optimize)
|
||||||
|
|
||||||
|
if best_val_score is None:
|
||||||
|
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using inf/nan.")
|
||||||
|
# Handle cases where training might have failed or metric wasn't logged
|
||||||
|
best_val_score = float('inf') if monitor_mode == 'min' else float('-inf') # Return worst possible value
|
||||||
|
else:
|
||||||
|
best_val_score = best_val_score.item() # Convert tensor to float
|
||||||
|
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")
|
||||||
|
|
||||||
|
fold_best_val_metrics.append(best_val_score)
|
||||||
|
|
||||||
|
# --- Intermediate Pruning Report (Optional but Recommended) ---
|
||||||
|
# Report the intermediate value (best score for this fold) to Optuna
|
||||||
|
# trial.report(best_val_score, fold_id) # Report score at step `fold_id`
|
||||||
|
# Check if the trial should be pruned based on reported values
|
||||||
|
# if trial.should_prune():
|
||||||
|
# logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
|
||||||
|
# raise optuna.TrialPruned()
|
||||||
|
|
||||||
|
logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")
|
||||||
|
|
||||||
|
except optuna.TrialPruned:
|
||||||
|
# Re-raise prune exception to let Optuna handle it
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
|
||||||
|
# Record a failure for this fold (e.g., append NaN or worst value)
|
||||||
|
fold_best_val_metrics.append(float('inf') if monitor_mode == 'min' else float('-inf'))
|
||||||
|
# Optionally: Break the CV loop for this trial if one fold fails catastrophically?
|
||||||
|
# break
|
||||||
|
|
||||||
|
|
||||||
|
# --- 4. Calculate Average Metric Across Folds ---
|
||||||
|
if not fold_best_val_metrics:
|
||||||
|
logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
|
||||||
|
return float('inf') # Return worst value
|
||||||
|
|
||||||
|
# Handle potential infinities or NaNs from failed folds
|
||||||
|
valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
|
||||||
|
if not valid_scores:
|
||||||
|
logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
|
||||||
|
return float('inf')
|
||||||
|
|
||||||
|
average_val_metric = np.mean(valid_scores)
|
||||||
|
logger.info(f"--- Trial {trial.number}: Finished ---")
|
||||||
|
logger.info(f" Average validation {metric_to_optimize}: {average_val_metric:.5f}")
|
||||||
|
logger.info(f" Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
|
||||||
|
|
||||||
|
# --- 5. Return Metric for Optuna ---
|
||||||
|
return average_val_metric
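A sketch of the functools.partial wiring mentioned further down as an alternative to the module-level `args` plus lambda; note that the objective above would also need the optimization direction passed in explicitly, since it currently reads args.direction for monitor_mode.

from functools import partial

# Inside run_hpo(), once base_config, df, output_dir and args are available:
bound_objective = partial(
    objective,
    base_config=base_config,
    df=df,
    output_base_dir=output_dir,
    metric_to_optimize=args.metric_to_optimize,
    enable_pruning=args.pruning,
)
# study.optimize(bound_objective, n_trials=args.n_trials)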
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main HPO Execution ---
|
||||||
|
def run_hpo():
|
||||||
|
"""Main execution function for HPO."""
|
||||||
|
global args # Make args accessible in objective (simplifies passing) - or use functools.partial
|
||||||
|
args = parse_arguments()
|
||||||
|
config_path = Path(args.config)
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True) # Ensure output dir exists
|
||||||
|
|
||||||
|
# Adjust log level if debug flag is set
|
||||||
|
if args.debug:
|
||||||
|
root_logger.setLevel(logging.DEBUG)
|
||||||
|
optuna_lg.setLevel(logging.DEBUG)
|
||||||
|
pl_logger.setLevel(logging.DEBUG)
|
||||||
|
logger.debug("Debug mode enabled.")
|
||||||
|
else:
|
||||||
|
# Reduce verbosity during HPO runs
|
||||||
|
optuna_lg.setLevel(logging.WARNING)
|
||||||
|
pl_logger.setLevel(logging.INFO) # Keep INFO for PL start/end messages
|
||||||
|
|
||||||
|
# --- Configuration Loading ---
|
||||||
|
try:
|
||||||
|
base_config = load_config(config_path)
|
||||||
|
except Exception:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# --- Seed Setting (for HPO script itself) ---
|
||||||
|
set_seeds(args.seed)
|
||||||
|
|
||||||
|
# --- Load Data Once ---
|
||||||
|
# Assume data doesn't change based on HPs (unless sequence_length is tuned heavily)
|
||||||
|
try:
|
||||||
|
logger.info("Loading base dataset...")
|
||||||
|
df = load_raw_data(base_config.data)
|
||||||
|
logger.info("Base dataset loaded.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Optuna Study Setup ---
|
||||||
|
storage_path = args.storage_db
|
||||||
|
if storage_path:
|
||||||
|
# Ensure directory exists if using SQLite file storage
|
||||||
|
db_path = Path(storage_path.replace("sqlite:///", ""))
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
storage_path = f"sqlite:///{db_path.resolve()}" # Use absolute path
|
||||||
|
logger.info(f"Using Optuna storage: {storage_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create or load the study
|
||||||
|
study = optuna.create_study(
|
||||||
|
study_name=args.study_name,
|
||||||
|
storage=storage_path,
|
||||||
|
direction=args.direction,
|
||||||
|
load_if_exists=True, # Load previous results if study exists
|
||||||
|
pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() # Example pruner
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Run Optimization ---
|
||||||
|
logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
|
||||||
|
study.optimize(
|
||||||
|
lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
|
||||||
|
n_trials=args.n_trials,
|
||||||
|
timeout=None # Optional: Set timeout in seconds
|
||||||
|
# Optional: Add callbacks (e.g., logging callback)
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Report Best Trial ---
|
||||||
|
logger.info("--- Optuna HPO Finished ---")
|
||||||
|
logger.info(f"Number of finished trials: {len(study.trials)}")
|
||||||
|
|
||||||
|
best_trial = study.best_trial
|
||||||
|
logger.info(f"Best trial number: {best_trial.number}")
|
||||||
|
logger.info(f" Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
|
||||||
|
logger.info(" Best hyperparameters:")
|
||||||
|
for key, value in best_trial.params.items():
|
||||||
|
logger.info(f" {key}: {value}")
|
||||||
|
|
||||||
|
# --- Save Best Hyperparameters (Optional) ---
|
||||||
|
best_params_file = output_dir / f"{args.study_name}_best_params.json"
|
||||||
|
try:
|
||||||
|
with open(best_params_file, 'w') as f:
|
||||||
|
import json
|
||||||
|
json.dump(best_trial.params, f, indent=4)
|
||||||
|
logger.info(f"Best hyperparameters saved to {best_params_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save best parameters: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_hpo()
|
rules.md (new file, 68 lines)
@@ -0,0 +1,68 @@
## Coding Style Rules & Paradigms

### Configuration Driven

* Uses **Pydantic** heavily (`utils/models.py`) to define configuration schemas.
* Configuration is loaded from a **YAML file** (`config.yaml`) at runtime (`main.py`).
* The `Config` object (or relevant sub-configs) is **passed down** through function calls, making parameters explicit.
* A **template configuration** (`_config.yaml`) is often included within the package.

### Modularity

* Code is organized into **logical sub-packages** (`io`, `processing`, `pipeline`, `visualization`, `synthesis`, `utils`, `validation`).
* Each sub-package has an `__init__.py`, often used to **expose key functions/classes** to the parent level.
* **Helper functions** (often internal, prefixed with `_`) are frequently used to break down complex logic within modules (e.g., `processing/surface_helper.py`, `pipeline/runner.py` helpers).

### Logging

* Uses the standard **`logging` library**.
* Loggers are obtained per module using `logger = logging.getLogger(__name__)`.
* **Logging levels** (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`) are used semantically:
    * `DEBUG`: Verbose internal steps.
    * `INFO`: Major milestones/stages.
    * `WARNING`: Recoverable issues or deviations.
    * `ERROR`: Specific failures that might be handled.
    * `CRITICAL`: Fatal errors causing exits.
* **Root logger configuration** happens in `main.py`, potentially adjusted based on the `debug` flag in the config.

### Error Handling ("Fail Hard but Helpful")

* The main entry point (`main.py`) uses a **top-level `try...except` block** to catch major failures during config loading or pipeline execution.
* **Critical errors** are logged with tracebacks (`exc_info=True`) and result in `sys.exit(1)`.
* Functions often return a **tuple indicating success/failure** and results/error messages (e.g., `(result_data, error_message)` or `(success_flag, result_data)`); see the sketch after this section.
* Lower-level functions may log errors/warnings but **allow processing to continue** if feasible and configured (e.g., `allow_segmentation_errors`).
* **Specific exceptions** are caught where appropriate (`FileNotFoundError`, `pydicom.errors.InvalidDicomError`, `ValueError`, etc.).
* **Pydantic validation errors** during config loading are treated as critical.
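A minimal sketch of the success/error tuple pattern described above (function and file names are hypothetical):

```python
import logging
from typing import Optional, Tuple
import numpy as np

logger = logging.getLogger(__name__)

def load_volume(path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """Return (result_data, error_message); exactly one of the two is None."""
    try:
        return np.load(path), None
    except (FileNotFoundError, ValueError) as exc:
        return None, f"Failed to load {path}: {exc}"

volume, error = load_volume("data/example.npy")
if error:
    logger.warning(error)
```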
### Typing

* Consistent use of **Python type hints** (`typing` module: `Optional`, `Dict`, `List`, `Tuple`, `Union`, `Callable`, `Literal`, etc.).
* **Pydantic models** rely heavily on type hints for validation.

### Data Structures

* **Pydantic models** define primary configuration and result structures (e.g., `Config`, `ProcessingResult`, `CombinedDicomDataset`).
* **NumPy arrays** are fundamental for image/volume data.
* **Pandas DataFrames** are used for aggregating results, metadata, and creating reports (Excel).
* Standard **Python dictionaries** are used extensively for metadata and intermediate data passing.

### Naming Conventions

* Follows **PEP 8**: `snake_case` for variables and functions, `PascalCase` for classes.
* Internal helper functions are typically prefixed with an **underscore (`_`)**.
* Constants are defined in **`UPPER_SNAKE_CASE`** (often in a dedicated `utils/constants.py`).

### Documentation

* **Docstrings** are present for most functions and classes, explaining purpose, arguments (`Args:`), and return values (`Returns:`).
* Minimal **inline comments**; code aims to be self-explanatory, with docstrings providing higher-level context. (Matches your custom instructions).

### Dependencies

* Managed via `requirements.txt`.
* Uses the standard **scientific Python stack** (`numpy`, `pandas`, `scipy`, `scikit-image`, `matplotlib`), **domain-specific libraries** (`pydicom`), **utility libraries** (`PyYAML`, `joblib`, `tqdm`, `openpyxl`), and `pydantic` for configuration/validation.

### Parallelism

* Uses **`joblib`** for parallel processing, configurable via the main config (`mainprocess_core_count`, `subprocess_core_count`); see the sketch after this list.
* Parallelism can be **disabled** via configuration or debug mode.
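A minimal joblib sketch of the pattern above (the per-case function and the config field wiring are hypothetical):

```python
from joblib import Parallel, delayed

def _process_case(case_id: str) -> dict:
    return {"case": case_id, "ok": True}

def run_cases(case_ids, mainprocess_core_count: int, debug: bool = False):
    n_jobs = 1 if debug else mainprocess_core_count  # parallelism disabled in debug mode
    return Parallel(n_jobs=n_jobs)(delayed(_process_case)(c) for c in case_ids)
```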
|
Reference in New Issue
Block a user