import logging

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
from typing import Optional, Dict, Any, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Assuming config_model is in sibling directory utils/
from forecasting_model.utils.config_model import ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)


class LSTMForecastLightningModule(pl.LightningModule):
    """
    PyTorch Lightning Module for LSTM-based time series forecasting.

    Encapsulates the model architecture, training, validation, and test logic.
    Uses torchmetrics for efficient metric calculation.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        train_config: TrainingConfig,
        input_size: int,
        target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
    ):
        super().__init__()

        # --- Validate & Store Configs ---
        # Validate the input_size passed during instantiation
        if input_size <= 0:
            raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")
        # Store the validated input_size directly for use in layer definitions
        self._input_size = input_size  # Temporary attribute before hparams are saved

        # Ensure forecast_horizon is set in the config for the output layer
        if (not hasattr(model_config, 'forecast_horizon')
                or model_config.forecast_horizon is None
                or model_config.forecast_horizon <= 0):
            raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
        self.output_size = model_config.forecast_horizon

        # Store configurations - the input_size argument will be saved via save_hyperparameters
        self.model_config = model_config
        self.train_config = train_config
        self.target_scaler = target_scaler  # Store the scaler for this fold

        # Use save_hyperparameters() to automatically log configs and allow loading.
        # Pass input_size explicitly so it is saved in hparams; exclude the scaler
        # because it is stateful and fold-specific.
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])

        # --- Define Model Layers ---
        # Access input_size via hparams from here on
        self.lstm = nn.LSTM(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.model_config.hidden_size,
            num_layers=self.hparams.model_config.num_layers,
            batch_first=True,  # Input shape: (batch, seq_len, features)
            dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(self.hparams.model_config.dropout)

        # Output layer maps the LSTM hidden state to the forecast horizon.
        # We typically take the output of the last time step.
        self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)

        # Optional residual connection handling
        self.use_residual_skips = self.hparams.model_config.use_residual_skips
        self.residual_projection = None
        if self.use_residual_skips:
            # If input size doesn't match hidden size, project the input
            if self.hparams.input_size != self.hparams.model_config.hidden_size:
                self.residual_projection = nn.Linear(self.hparams.input_size, self.hparams.model_config.hidden_size)
            logger.info("Residual connections enabled.")
            if self.residual_projection:
                logger.info("Residual projection layer added.")

        # --- Define Loss Function ---
        if self.hparams.train_config.loss_function.upper() == 'MSE':
            self.criterion = nn.MSELoss()
        elif self.hparams.train_config.loss_function.upper() == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Unsupported loss function: {self.hparams.train_config.loss_function}")

        # --- Define Metrics (TorchMetrics) ---
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanAbsoluteError(),
            torchmetrics.MeanSquaredError(squared=False),  # RMSE
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')
        self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the LSTM network.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, input_size)

        Returns:
            Predictions tensor of shape (batch_size, forecast_horizon)
        """
        # LSTM forward pass
        lstm_out, (hidden, cell) = self.lstm(x)  # Shape: (batch, seq_len, hidden_size)

        # Output from the last time step
        last_time_step_out = lstm_out[:, -1, :]  # Shape: (batch_size, hidden_size)

        # Apply dropout
        last_time_step_out = self.dropout(last_time_step_out)

        # Optional residual connection
        if self.use_residual_skips:
            residual = x[:, -1, :]  # Input from the last time step: (batch_size, input_size)
            if self.residual_projection:
                residual = self.residual_projection(residual)  # Project to hidden_size
            last_time_step_out = last_time_step_out + residual

        # Final fully connected layer
        predictions = self.fc(last_time_step_out)  # Shape: (batch_size, output_size/horizon)

        return predictions  # Shape: (batch_size, forecast_horizon)

    def _calculate_loss(self, outputs, targets):
        # Ensure shapes match before loss calculation
        if outputs.shape != targets.shape:
            # Squeeze a potential extra dim: (batch, horizon, 1) -> (batch, horizon)
            if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
                outputs = outputs.squeeze(-1)
            if outputs.shape != targets.shape:
                raise ValueError(
                    f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation."
                )
        return self.criterion(outputs, targets)

    def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
        """Helper to inverse transform data using the stored target scaler."""
        if self.target_scaler is None:
            # logger.warning("Cannot inverse transform: target_scaler not available.")
            return None  # Cannot inverse transform

        # The scaler expects 2D input (N, 1); move data to CPU and cast to float64,
        # which sklearn scalers typically expect.
        data_cpu = data.detach().cpu().numpy().astype(np.float64)
        original_shape = data_cpu.shape

        if data_cpu.ndim == 1:
            data_flat = data_cpu.reshape(-1, 1)
        elif data_cpu.ndim == 2:  # (batch, horizon)
            data_flat = data_cpu.reshape(-1, 1)
        else:
            logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
            data_flat = data_cpu.reshape(-1, 1)

        try:
            inversed_np = self.target_scaler.inverse_transform(data_flat)
            # Return as a tensor on the original device
            inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
            # Keep the result flat for direct metric use; reshape back if the original shape is needed.
            return inversed_tensor.flatten()
            # return inversed_tensor.reshape(original_shape)  # If the original shape is needed
        except Exception as e:
            logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
            return None  # Return None if the inverse transform fails

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch  # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
        outputs = self(x)  # Scaled outputs: (batch, horizon)
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        self.train_metrics(outputs, y)  # Update internal metric state
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)  # Log all metrics in the collection
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        outputs = self(x)  # Scaled outputs
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        self.val_metrics(outputs, y)  # Update internal metric state
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)

        # Log MAE on the ORIGINAL scale if a scaler is available
        # (often the primary metric for checkpointing/Optuna).
        if self.target_scaler is not None:
            outputs_inv = self._inverse_transform(outputs)
            y_inv = self._inverse_transform(y)
            if outputs_inv is not None and y_inv is not None:
                # Ensure shapes are compatible (both tensors are flattened by _inverse_transform)
                if outputs_inv.shape == y_inv.shape:
                    self.val_mae_original_scale.update(outputs_inv, y_inv)
                    self.log('val_mae_orig_scale', self.val_mae_original_scale,
                             on_step=False, on_epoch=True, prog_bar=True, logger=True)
                else:
                    logger.warning(f"Shape mismatch after inverse transform in validation: "
                                   f"Preds {outputs_inv.shape}, Targets {y_inv.shape}")
            else:
                logger.warning("Could not compute original-scale MAE in validation due to inverse transform failure.")

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        # Optional: keep this method only if trainer.test() should log metrics.
        # For collecting predictions for evaluation, use predict_step instead.
        # If evaluate_fold_predictions handles all metrics, this can be simplified or removed;
        # for now it logs the loss and the scaled test metrics.
        try:
            x, y = batch
            outputs = self(x)
            loss = self._calculate_loss(outputs, y)

            # Log scaled test metrics so trainer.test() still reports them
            self.test_metrics(outputs, y)  # Update internal metric state
            self.log('test_loss_step', loss, on_step=True, on_epoch=False)  # Step-level loss, if needed
            self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True)
            # No return value needed when only logging
        except Exception as e:
            logger.error(f"Error occurred in test_step for batch {batch_idx}: {e}", exc_info=True)
            # Optionally log something to indicate failure

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Dict[str, torch.Tensor]:
        """
        Runs inference for prediction and returns scaled predictions (and targets, if available).

        'batch' might contain only features depending on the DataLoader setup for predict;
        here the test_loader is assumed to yield (x, y) pairs for convenience.
""" if isinstance(batch, (list, tuple)) and len(batch) == 2: x, y = batch else: # Assume batch contains only features if not a pair x = batch y = None # No targets available during prediction if dataloader only yields features outputs = self(x) # Scaled outputs result = {'preds_scaled': outputs.detach().cpu()} if y is not None: # Include targets if they were part of the batch (e.g., using test_loader for predict) result['targets_scaled'] = y.detach().cpu() return result def configure_optimizers(self) -> Union[optim.Optimizer, Tuple[List[optim.Optimizer], List[Dict[str, Any]]]]: """ Configure the optimizer (Adam) and optional LR scheduler. """ optimizer = optim.Adam( self.parameters(), lr=self.hparams.train_config.learning_rate # Access lr via hparams ) logger.info(f"Configured Adam optimizer with LR: {self.hparams.train_config.learning_rate}") # Optional LR Scheduler configuration scheduler_config = None if hasattr(self.hparams.train_config, 'scheduler_step_size') and \ self.hparams.train_config.scheduler_step_size is not None and \ hasattr(self.hparams.train_config, 'scheduler_gamma') and \ self.hparams.train_config.scheduler_gamma is not None: if self.hparams.train_config.scheduler_step_size > 0 and 0 < self.hparams.train_config.scheduler_gamma < 1: logger.info(f"Configuring StepLR scheduler with step_size={self.hparams.train_config.scheduler_step_size} " f"and gamma={self.hparams.train_config.scheduler_gamma}") scheduler = optim.lr_scheduler.StepLR( optimizer, step_size=self.hparams.train_config.scheduler_step_size, gamma=self.hparams.train_config.scheduler_gamma ) scheduler_config = { 'scheduler': scheduler, 'interval': 'epoch', # or 'step' 'frequency': 1, 'monitor': 'val_loss', # Optional: Only step if monitor improves (for ReduceLROnPlateau) } else: logger.warning("Scheduler parameters provided but invalid (step_size must be >0, 0