entrix_case_challange/optimizer/forecasting/ensemble.py
import logging
from typing import List, Dict, Any, Optional
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast
logger = logging.getLogger(__name__)


class EnsembleProvider(ForecastProvider):
"""Provides forecasts using an ensemble of trained LSTM models."""
def __init__(
self,
fold_artifacts: List[Dict[str, Any]],
ensemble_method: str,
ensemble_feature_config: FeatureConfig, # Assumed consistent across folds by loading logic
ensemble_target_col: str, # Assumed consistent
):
if not fold_artifacts:
raise ValueError("EnsembleProvider requires at least one fold artifact.")
self.fold_artifacts = fold_artifacts
self.ensemble_method = ensemble_method
# Store common config for reference, but use fold-specific details in get_forecast
self.ensemble_feature_config = ensemble_feature_config
self.ensemble_target_col = ensemble_target_col
self.common_forecast_horizons = sorted(ensemble_feature_config.forecast_horizon) # Assumed consistent
# Calculate max lookback needed across all folds
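        # A fold's total lookback is its LSTM sequence length plus the extra history its
        # lag / rolling-window features require; the provider must cover the max over all folds.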
max_lookback = 0
for i, fold in enumerate(fold_artifacts):
try:
fold_feature_config = fold['feature_config']
fold_seq_len = fold_feature_config.sequence_length
feature_lookback = 0
if fold_feature_config.lags:
feature_lookback = max(feature_lookback, max(fold_feature_config.lags))
if fold_feature_config.rolling_window_sizes:
feature_lookback = max(feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
fold_total_lookback = fold_seq_len + feature_lookback
max_lookback = max(max_lookback, fold_total_lookback)
except KeyError as e:
raise ValueError(f"Fold artifact {i} is missing expected key: {e}") from e
except Exception as e:
raise ValueError(f"Error processing fold artifact {i} for lookback calculation: {e}") from e
self._required_lookback = max_lookback
logger.debug(f"EnsembleProvider initialized with {len(fold_artifacts)} folds. Method: '{ensemble_method}'. Required lookback: {self._required_lookback}")
if ensemble_method not in ['mean', 'median']:
raise ValueError(f"Unsupported ensemble method: {ensemble_method}. Use 'mean' or 'median'.")

    def get_required_lookback(self) -> int:
        return self._required_lookback

    def get_forecast(
self,
historical_data_slice: pd.DataFrame,
optimization_horizon_hours: int
) -> np.ndarray | None:
"""
Generates forecasts from each fold model, interpolates, and aggregates.
"""
logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours using {self.ensemble_method}.")
if len(historical_data_slice) < self._required_lookback:
logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
return None
fold_forecasts_interpolated = []
last_actual_price = historical_data_slice[self.ensemble_target_col].iloc[-1] # Common anchor for all folds
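        # Passed to interpolate_forecast() as last_known_actual so each fold's interpolated
        # curve is anchored at the same most recent observed price.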
for i, fold_artifact in enumerate(self.fold_artifacts):
fold_id = fold_artifact.get("fold_id", i + 1)
try:
fold_model: LSTMForecastLightningModule = fold_artifact['model_instance']
fold_feature_config: FeatureConfig = fold_artifact['feature_config']
fold_target_scaler: Optional[Any] = fold_artifact['target_scaler']
fold_target_col: str = fold_artifact['main_forecasting_config'].data.target_col # Use fold specific target
fold_seq_len = fold_feature_config.sequence_length
fold_horizons = sorted(fold_feature_config.forecast_horizon)
# Calculate lookback needed *for this specific fold* to check slice length
fold_feature_lookback = 0
                if fold_feature_config.lags:
                    fold_feature_lookback = max(fold_feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    fold_feature_lookback = max(fold_feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
fold_total_lookback = fold_seq_len + fold_feature_lookback
if len(historical_data_slice) < fold_total_lookback:
logger.warning(f"Fold {fold_id}: Skipping fold. Insufficient historical data in slice for this fold's lookback ({fold_total_lookback} needed).")
continue
# 1. Feature Engineering (using fold's config)
# Slice needs to be long enough for this fold's total lookback.
# The input slice `historical_data_slice` should already be long enough based on max_lookback.
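                # engineer_features() is expected to add this fold's lag and rolling-window
                # features; the earliest rows may contain NaNs until enough history is
                # available, which the fill below handles.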
engineered_df_fold = engineer_features(historical_data_slice.copy(), fold_target_col, fold_feature_config)
if engineered_df_fold.isnull().any().any():
logger.warning(f"Fold {fold_id}: NaNs found after feature engineering. Attempting fill.")
engineered_df_fold = engineered_df_fold.ffill().bfill()
if engineered_df_fold.isnull().any().any():
logger.error(f"Fold {fold_id}: NaNs persist after fill. Skipping fold.")
continue
# 2. Create *one* input sequence (using fold's sequence length)
if len(engineered_df_fold) < fold_seq_len:
logger.error(f"Fold {fold_id}: Engineered data ({len(engineered_df_fold)}) is shorter than fold sequence length ({fold_seq_len}). Skipping fold.")
continue
input_sequence_data_fold = engineered_df_fold.iloc[-fold_seq_len:].copy()
                # Use all engineered columns except the target as model inputs; if that leaves
                # nothing (target-only frame), fall back to every column.
                feature_columns_fold = [col for col in engineered_df_fold.columns if col != fold_target_col]
                if not feature_columns_fold:
                    feature_columns_fold = engineered_df_fold.columns.tolist()
input_sequence_np_fold = input_sequence_data_fold[feature_columns_fold].values
if input_sequence_np_fold.shape != (fold_seq_len, len(feature_columns_fold)):
logger.error(f"Fold {fold_id}: Input sequence has wrong shape. Expected ({fold_seq_len}, {len(feature_columns_fold)}), got {input_sequence_np_fold.shape}. Skipping fold.")
continue
input_tensor_fold = torch.FloatTensor(input_sequence_np_fold).unsqueeze(0)
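                # Tensor shape: (1, fold_seq_len, n_features) -- a single-sample batch for the model.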
# 3. Run Inference (using fold's model)
fold_model.eval()
with torch.no_grad():
predictions_scaled_fold = fold_model(input_tensor_fold) # Shape (1, num_fold_horizons)
if predictions_scaled_fold.ndim != 2 or predictions_scaled_fold.shape[0] != 1 or predictions_scaled_fold.shape[1] != len(fold_horizons):
logger.error(f"Fold {fold_id}: Prediction output shape mismatch. Expected (1, {len(fold_horizons)}), got {predictions_scaled_fold.shape}. Skipping fold.")
continue
predictions_scaled_np_fold = predictions_scaled_fold.squeeze(0).cpu().numpy()
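                # Shape: (num_fold_horizons,) -- one scaled prediction per native forecast horizon.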
# 4. Inverse Transform (using fold's scaler)
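                # The fold's scaler is expected to be a fitted sklearn scaler (e.g. StandardScaler
                # or MinMaxScaler) from training; without one, predictions are used as-is.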
predictions_original_scale_fold = predictions_scaled_np_fold
if fold_target_scaler:
try:
predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to apply inverse transform: {e}. Skipping fold.", exc_info=True)
continue
# 5. Interpolate (using fold's horizons)
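                # interpolate_forecast() presumably maps the fold's native horizons onto an hourly
                # grid of length optimization_horizon_hours and returns None if it cannot.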
interpolated_forecast_fold = interpolate_forecast(
native_horizons=fold_horizons,
native_predictions=predictions_original_scale_fold,
target_horizon=optimization_horizon_hours,
last_known_actual=last_actual_price
)
if interpolated_forecast_fold is not None:
fold_forecasts_interpolated.append(interpolated_forecast_fold)
logger.debug(f"Fold {fold_id}: Successfully generated interpolated forecast.")
else:
logger.warning(f"Fold {fold_id}: Interpolation failed. Skipping fold.")
except Exception as e:
logger.error(f"Error processing ensemble fold {fold_id}: {e}", exc_info=True)
continue # Skip this fold on error
# --- Aggregation ---
if not fold_forecasts_interpolated:
logger.error("No successful forecasts generated from any ensemble folds.")
return None
logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} folds using '{self.ensemble_method}'.")
stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0) # Shape (n_folds, target_horizon)
if self.ensemble_method == 'mean':
final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
elif self.ensemble_method == 'median':
final_ensemble_forecast = np.median(stacked_predictions, axis=0)
else:
# Should be caught in __init__, but double-check
logger.error(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
return None
logger.debug(f"EnsembleProvider: Successfully generated forecast.")
return final_ensemble_forecast
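

# --- Illustrative usage sketch (kept as comments; the loading helper and the target
# --- column name are hypothetical, but the artifact keys mirror those accessed above) ---
#
#   fold_artifacts = load_ensemble_artifacts(...)   # hypothetical loader returning dicts with
#                                                   # 'model_instance', 'feature_config',
#                                                   # 'target_scaler', 'main_forecasting_config'
#   provider = EnsembleProvider(
#       fold_artifacts=fold_artifacts,
#       ensemble_method="median",
#       ensemble_feature_config=fold_artifacts[0]["feature_config"],
#       ensemble_target_col="Price",                # hypothetical target column name
#   )
#   lookback = provider.get_required_lookback()
#   forecast = provider.get_forecast(history_df.iloc[-lookback:], optimization_horizon_hours=24)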