import logging
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Imports from our project structure
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast

logger = logging.getLogger(__name__)


class SingleModelProvider(ForecastProvider):
    """Provides forecasts using a single trained LSTM model."""

    def __init__(
        self,
        model_instance: LSTMForecastLightningModule,
        feature_config: FeatureConfig,
        target_col: str,
        target_scaler: Optional[Any],  # BaseEstimator/TransformerMixin -> more specific if possible
    ):
        """
        Args:
            model_instance: Trained LSTM Lightning module used for inference.
            feature_config: Feature-engineering configuration (lags, rolling
                windows, sequence length, native forecast horizons).
            target_col: Name of the target column in the input DataFrame.
            target_scaler: Optional fitted scaler whose ``inverse_transform``
                maps model outputs back to the original scale; ``None`` means
                the model already predicts in the original scale.
        """
        self.model = model_instance
        self.feature_config = feature_config
        self.target_col = target_col
        self.target_scaler = target_scaler
        self.sequence_length = feature_config.sequence_length
        self.forecast_horizons = sorted(feature_config.forecast_horizon)  # ensure ascending order

        # Lookback needed purely for feature engineering (lags / rolling stats).
        feature_lookback = 0
        if feature_config.lags:
            feature_lookback = max(feature_lookback, max(feature_config.lags))
        if feature_config.rolling_window_sizes:
            # A rolling window of size W needs W-1 points before the first valid row.
            feature_lookback = max(
                feature_lookback,
                max(w - 1 for w in feature_config.rolling_window_sizes),
            )

        # Total lookback: `sequence_length` points for the last model input
        # sequence, plus `feature_lookback` points before the first element of
        # that sequence so its engineered features are well defined. The input
        # sequence ends just before the first forecast point (t=1), so
        # `sequence_length + feature_lookback` rows are required before t=1.
        self._required_lookback = self.sequence_length + feature_lookback
        logger.debug(
            "SingleModelProvider initialized. Required lookback: %d (SeqLen: %d, FeatLookback: %d)",
            self._required_lookback,
            self.sequence_length,
            feature_lookback,
        )

    def get_required_lookback(self) -> int:
        """Number of historical rows needed before the first forecast point."""
        return self._required_lookback

    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int,
    ) -> np.ndarray | None:
        """
        Generates forecast using the single model and interpolates to hourly resolution.

        Args:
            historical_data_slice: History ending at t=0, already including the
                required lookback rows (see :meth:`get_required_lookback`).
            optimization_horizon_hours: Number of hourly steps to forecast.

        Returns:
            Hourly forecast in the original target scale, or ``None`` on any
            failure (insufficient data, NaNs, shape mismatch, scaler or
            interpolation error).
        """
        logger.debug(
            "SingleModelProvider: Generating forecast for %d hours.",
            optimization_horizon_hours,
        )

        if len(historical_data_slice) < self._required_lookback:
            logger.error(
                "Insufficient historical data provided. Need %d, got %d.",
                self._required_lookback,
                len(historical_data_slice),
            )
            return None

        try:
            # 1. Feature engineering (the slice already includes the lookback).
            engineered_df = self._engineer_features_checked(historical_data_slice)
            if engineered_df is None:
                return None

            # 2. Build the single input sequence ending at the last historical
            #    point; it predicts starting from the next hour (t=1).
            input_tensor = self._build_input_tensor(engineered_df)
            if input_tensor is None:
                return None

            # 3. Run inference -> scaled predictions at the native horizons.
            predictions = self._run_inference(input_tensor)
            if predictions is None:
                return None

            # 4. Map predictions back to the original scale.
            predictions = self._to_original_scale(predictions)
            if predictions is None:
                return None

            # 5. Interpolate to hourly resolution, anchored at the last
            #    observed target value (t=0).
            last_actual_price = historical_data_slice[self.target_col].iloc[-1]
            interpolated_forecast = interpolate_forecast(
                native_horizons=self.forecast_horizons,
                native_predictions=predictions,
                target_horizon=optimization_horizon_hours,
                last_known_actual=last_actual_price,
            )
            if interpolated_forecast is None:
                logger.error("Interpolation step failed.")
                return None

            logger.debug("SingleModelProvider: Successfully generated forecast.")
            return interpolated_forecast

        except Exception as e:
            logger.error(
                "Error during single model forecast generation: %s", e, exc_info=True
            )
            return None

    def _engineer_features_checked(
        self, historical_data_slice: pd.DataFrame
    ) -> Optional[pd.DataFrame]:
        """Run feature engineering; fill or reject any resulting NaNs."""
        engineered_df = engineer_features(
            historical_data_slice.copy(), self.target_col, self.feature_config
        )
        if engineered_df.isnull().any().any():
            logger.warning(
                "NaNs found after feature engineering. Attempting to fill with ffill/bfill."
            )
            # NOTE(review): this fills target and feature columns alike — be
            # careful if they ever need different treatment.
            engineered_df = engineered_df.ffill().bfill()
            if engineered_df.isnull().any().any():
                logger.error("NaNs persist after fill. Cannot create sequences.")
                return None
        return engineered_df

    def _build_input_tensor(self, engineered_df: pd.DataFrame) -> Optional[torch.Tensor]:
        """Build the (1, sequence_length, n_features) tensor from the last rows."""
        if len(engineered_df) < self.sequence_length:
            logger.error(
                "Engineered data (%d) is shorter than sequence length (%d).",
                len(engineered_df),
                self.sequence_length,
            )
            return None

        input_sequence_data = engineered_df.iloc[-self.sequence_length:]

        # TODO: Verify the exact features the model expects (target included/excluded?)
        # Assumption: all engineered columns except the original target are
        # model features; fall back to every column if the target was dropped.
        feature_columns = [c for c in engineered_df.columns if c != self.target_col]
        if not feature_columns:
            feature_columns = engineered_df.columns.tolist()

        input_sequence_np = input_sequence_data[feature_columns].values
        expected_shape = (self.sequence_length, len(feature_columns))
        if input_sequence_np.shape != expected_shape:
            logger.error(
                "Input sequence has wrong shape. Expected %s, got %s",
                expected_shape,
                input_sequence_np.shape,
            )
            return None

        return torch.FloatTensor(input_sequence_np).unsqueeze(0)  # add batch dim

    def _run_inference(self, input_tensor: torch.Tensor) -> Optional[np.ndarray]:
        """Run the model and validate the (1, num_horizons) scaled output."""
        self.model.eval()
        with torch.no_grad():
            predictions_scaled = self.model(input_tensor)

        n_horizons = len(self.forecast_horizons)
        if (
            predictions_scaled.ndim != 2
            or predictions_scaled.shape[0] != 1
            or predictions_scaled.shape[1] != n_horizons
        ):
            logger.error(
                "Model prediction output shape mismatch. Expected (1, %d), got %s.",
                n_horizons,
                tuple(predictions_scaled.shape),
            )
            return None

        return predictions_scaled.squeeze(0).cpu().numpy()  # shape: (num_horizons,)

    def _to_original_scale(
        self, predictions_scaled_np: np.ndarray
    ) -> Optional[np.ndarray]:
        """Inverse-transform predictions; a no-op when no scaler was provided."""
        if self.target_scaler is None:
            return predictions_scaled_np
        try:
            # Scaler expects shape (n_samples, n_features), even when n_features == 1.
            result = self.target_scaler.inverse_transform(
                predictions_scaled_np.reshape(-1, 1)
            ).flatten()
            logger.debug("Applied inverse transform to predictions.")
            return result
        except Exception as e:
            # Returning still-scaled values would silently corrupt downstream
            # optimization; failing with None is safer.
            logger.error("Failed to apply inverse transform: %s", e, exc_info=True)
            return None