import logging
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast

logger = logging.getLogger(__name__)


class EnsembleProvider(ForecastProvider):
    """Provides forecasts using an ensemble of trained LSTM models."""

    def __init__(
        self,
        fold_artifacts: List[Dict[str, Any]],
        ensemble_method: str,
        ensemble_feature_config: FeatureConfig,  # Assumed consistent across folds by loading logic
        ensemble_target_col: str,  # Assumed consistent
    ):
        if not fold_artifacts:
            raise ValueError("EnsembleProvider requires at least one fold artifact.")

        self.fold_artifacts = fold_artifacts
        self.ensemble_method = ensemble_method
        # Store the common config for reference, but use fold-specific details in get_forecast.
        self.ensemble_feature_config = ensemble_feature_config
        self.ensemble_target_col = ensemble_target_col
        self.common_forecast_horizons = sorted(ensemble_feature_config.forecast_horizon)  # Assumed consistent

        # Calculate the maximum lookback needed across all folds.
        max_lookback = 0
        for i, fold in enumerate(fold_artifacts):
            try:
                fold_feature_config = fold['feature_config']
                fold_seq_len = fold_feature_config.sequence_length

                feature_lookback = 0
                if fold_feature_config.lags:
                    feature_lookback = max(feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    feature_lookback = max(feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))

                fold_total_lookback = fold_seq_len + feature_lookback
                max_lookback = max(max_lookback, fold_total_lookback)
            except KeyError as e:
                raise ValueError(f"Fold artifact {i} is missing expected key: {e}") from e
            except Exception as e:
                raise ValueError(f"Error processing fold artifact {i} for lookback calculation: {e}") from e

        self._required_lookback = max_lookback
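        # Worked example (hypothetical numbers): a fold with sequence_length=24,
        # lags=[24, 48] and rolling_window_sizes=[12] has
        # feature_lookback = max(48, 12 - 1) = 48, so it needs 24 + 48 = 72 rows
        # of history; _required_lookback is the maximum of these per-fold totals.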
        logger.debug(f"EnsembleProvider initialized with {len(fold_artifacts)} folds. Method: '{ensemble_method}'. Required lookback: {self._required_lookback}")

        if ensemble_method not in ['mean', 'median']:
            raise ValueError(f"Unsupported ensemble method: {ensemble_method}. Use 'mean' or 'median'.")

    def get_required_lookback(self) -> int:
        return self._required_lookback

    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int
    ) -> Optional[np.ndarray]:
        """
        Generates a forecast from each fold model, interpolates each one to the
        optimization horizon, and aggregates them with the configured ensemble
        method. Returns None if no fold produces a usable forecast.
        """
        logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours using {self.ensemble_method}.")

        if len(historical_data_slice) < self._required_lookback:
            logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
            return None

        fold_forecasts_interpolated = []
        last_actual_price = historical_data_slice[self.ensemble_target_col].iloc[-1]  # Common anchor for all folds

        for i, fold_artifact in enumerate(self.fold_artifacts):
            fold_id = fold_artifact.get("fold_id", i + 1)
            try:
                fold_model: LSTMForecastLightningModule = fold_artifact['model_instance']
                fold_feature_config: FeatureConfig = fold_artifact['feature_config']
                fold_target_scaler: Optional[Any] = fold_artifact['target_scaler']
                fold_target_col: str = fold_artifact['main_forecasting_config'].data.target_col  # Use fold-specific target
                fold_seq_len = fold_feature_config.sequence_length
                fold_horizons = sorted(fold_feature_config.forecast_horizon)

                # Calculate the lookback needed *for this specific fold* to check the slice length.
                fold_feature_lookback = 0
                if fold_feature_config.lags:
                    fold_feature_lookback = max(fold_feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    fold_feature_lookback = max(fold_feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
                fold_total_lookback = fold_seq_len + fold_feature_lookback

                if len(historical_data_slice) < fold_total_lookback:
                    logger.warning(f"Fold {fold_id}: Skipping fold. Insufficient historical data in slice for this fold's lookback ({fold_total_lookback} needed).")
                    continue

                # 1. Feature engineering (using the fold's config).
                # The slice must cover this fold's total lookback; `historical_data_slice`
                # should already be long enough, since it was sized from the maximum lookback.
                engineered_df_fold = engineer_features(historical_data_slice.copy(), fold_target_col, fold_feature_config)

                if engineered_df_fold.isnull().any().any():
                    logger.warning(f"Fold {fold_id}: NaNs found after feature engineering. Attempting fill.")
                    engineered_df_fold = engineered_df_fold.ffill().bfill()
                    if engineered_df_fold.isnull().any().any():
                        logger.error(f"Fold {fold_id}: NaNs persist after fill. Skipping fold.")
                        continue

                # 2. Create *one* input sequence (using the fold's sequence length).
                if len(engineered_df_fold) < fold_seq_len:
                    logger.error(f"Fold {fold_id}: Engineered data ({len(engineered_df_fold)}) is shorter than fold sequence length ({fold_seq_len}). Skipping fold.")
                    continue

                input_sequence_data_fold = engineered_df_fold.iloc[-fold_seq_len:].copy()
                # Assumes features are all engineered columns except the target (falling back
                # to all columns); this must mirror the feature selection used at training time.
                feature_columns_fold = [col for col in engineered_df_fold.columns if col != fold_target_col]
                if not feature_columns_fold:
                    feature_columns_fold = engineered_df_fold.columns.tolist()
                input_sequence_np_fold = input_sequence_data_fold[feature_columns_fold].values

                if input_sequence_np_fold.shape != (fold_seq_len, len(feature_columns_fold)):
                    logger.error(f"Fold {fold_id}: Input sequence has wrong shape. Expected ({fold_seq_len}, {len(feature_columns_fold)}), got {input_sequence_np_fold.shape}. Skipping fold.")
                    continue

                input_tensor_fold = torch.FloatTensor(input_sequence_np_fold).unsqueeze(0)

                # 3. Run inference (using the fold's model).
                fold_model.eval()
                with torch.no_grad():
                    predictions_scaled_fold = fold_model(input_tensor_fold)  # Shape (1, num_fold_horizons)

                if predictions_scaled_fold.ndim != 2 or predictions_scaled_fold.shape[0] != 1 or predictions_scaled_fold.shape[1] != len(fold_horizons):
                    logger.error(f"Fold {fold_id}: Prediction output shape mismatch. Expected (1, {len(fold_horizons)}), got {predictions_scaled_fold.shape}. Skipping fold.")
                    continue

                predictions_scaled_np_fold = predictions_scaled_fold.squeeze(0).cpu().numpy()

                # 4. Inverse transform (using the fold's scaler).
                predictions_original_scale_fold = predictions_scaled_np_fold
                if fold_target_scaler:
                    try:
                        predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
                    except Exception as e:
                        logger.error(f"Fold {fold_id}: Failed to apply inverse transform: {e}. Skipping fold.", exc_info=True)
                        continue

                # 5. Interpolate (using the fold's horizons).
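                # `interpolate_forecast` (optimizer.forecasting.utils) is expected to map the
                # fold's native-horizon predictions onto an hourly grid covering
                # `optimization_horizon_hours`, anchored at the last known actual value.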
                interpolated_forecast_fold = interpolate_forecast(
                    native_horizons=fold_horizons,
                    native_predictions=predictions_original_scale_fold,
                    target_horizon=optimization_horizon_hours,
                    last_known_actual=last_actual_price
                )

                if interpolated_forecast_fold is not None:
                    fold_forecasts_interpolated.append(interpolated_forecast_fold)
                    logger.debug(f"Fold {fold_id}: Successfully generated interpolated forecast.")
                else:
                    logger.warning(f"Fold {fold_id}: Interpolation failed. Skipping fold.")

            except Exception as e:
                logger.error(f"Error processing ensemble fold {fold_id}: {e}", exc_info=True)
                continue  # Skip this fold on error

        # --- Aggregation ---
        if not fold_forecasts_interpolated:
            logger.error("No successful forecasts generated from any ensemble folds.")
            return None

        logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} folds using '{self.ensemble_method}'.")
        stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0)  # Shape (n_folds, target_horizon)
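        # Example (hypothetical sizes): with 3 folds and a 4-hour optimization horizon,
        # stacked_predictions has shape (3, 4); the mean/median along axis=0 collapses
        # it to a single (4,) ensemble forecast.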

        if self.ensemble_method == 'mean':
            final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
        elif self.ensemble_method == 'median':
            final_ensemble_forecast = np.median(stacked_predictions, axis=0)
        else:
            # Should be caught in __init__, but double-check.
            logger.error(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
            return None

        logger.debug("EnsembleProvider: Successfully generated forecast.")
        return final_ensemble_forecast
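

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this module). It assumes the
# caller has already loaded the per-fold artifacts; `load_fold_artifacts` and the
# 'price' target column below are hypothetical placeholders. Each artifact dict
# is expected to provide 'model_instance', 'feature_config', 'target_scaler' and
# 'main_forecasting_config', matching the keys accessed in get_forecast above.
#
#     fold_artifacts = load_fold_artifacts("output/ensemble_run")   # hypothetical helper
#     provider = EnsembleProvider(
#         fold_artifacts=fold_artifacts,
#         ensemble_method="median",
#         ensemble_feature_config=fold_artifacts[0]["feature_config"],
#         ensemble_target_col="price",                               # hypothetical column
#     )
#     lookback = provider.get_required_lookback()
#     history = full_history_df.iloc[-lookback:]   # must cover the required lookback
#     forecast = provider.get_forecast(history, optimization_horizon_hours=24)
# ---------------------------------------------------------------------------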