import logging
from typing import Optional, Dict, Any, Union, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Assuming config_model is in the sibling directory utils/
from forecasting_model.utils.config_model import ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)
|
|
class LSTMForecastLightningModule(pl.LightningModule):
    """
    PyTorch Lightning module for LSTM-based time series forecasting.

    Encapsulates the model architecture, training, validation, and test logic.
    Uses torchmetrics for efficient metric calculation.
    """
    def __init__(
        self,
        model_config: ModelConfig,
        train_config: TrainingConfig,
        input_size: int,
        target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
    ):
        super().__init__()

        # --- Validate & Store Configs ---
        # Validate the input_size passed during instantiation
        if input_size <= 0:
            raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")

        # Store the validated input_size; hparams.input_size is used once hyperparameters are saved
        self._input_size = input_size

        # Ensure forecast_horizon is set in the config for the output layer
        if not hasattr(model_config, 'forecast_horizon') or model_config.forecast_horizon is None or model_config.forecast_horizon <= 0:
            raise ValueError("ModelConfig requires `forecast_horizon` to be set and positive.")
        self.output_size = model_config.forecast_horizon

        # Store configurations; the input_size argument is saved via save_hyperparameters
        self.model_config = model_config
        self.train_config = train_config
        self.target_scaler = target_scaler  # Scaler fitted for this fold

        # save_hyperparameters() logs the configs and enables checkpoint loading.
        # input_size is passed explicitly so it is stored in hparams; the scaler
        # is excluded because it is stateful and fold-specific.
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])

        # --- Define Model Layers ---
        # Access input_size via hparams from here on
        self.lstm = nn.LSTM(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.model_config.hidden_size,
            num_layers=self.hparams.model_config.num_layers,
            batch_first=True,  # Input shape: (batch, seq_len, features)
            dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(self.hparams.model_config.dropout)

        # Output layer maps the last LSTM hidden state to the forecast horizon
        self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)

        # Optional residual connection handling
        self.use_residual_skips = self.hparams.model_config.use_residual_skips
        self.residual_projection = None
        if self.use_residual_skips:
            # If the input size does not match the hidden size, project the input
            if self.hparams.input_size != self.hparams.model_config.hidden_size:
                self.residual_projection = nn.Linear(self.hparams.input_size, self.hparams.model_config.hidden_size)
            logger.info("Residual connections enabled.")
            if self.residual_projection is not None:
                logger.info("Residual projection layer added.")

        # --- Define Loss Function ---
        if self.hparams.train_config.loss_function.upper() == 'MSE':
            self.criterion = nn.MSELoss()
        elif self.hparams.train_config.loss_function.upper() == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Unsupported loss function: {self.hparams.train_config.loss_function}")

        # --- Define Metrics (TorchMetrics) ---
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanAbsoluteError(),
            torchmetrics.MeanSquaredError(squared=False)  # RMSE
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

        # MAE on the original (unscaled) target scale, updated in validation_step
        self.val_mae_original_scale = torchmetrics.MeanAbsoluteError()

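    # Sketch (not part of the module): because save_hyperparameters() stores
    # 'model_config', 'train_config', and 'input_size' but ignores the scaler,
    # reloading a checkpoint only needs the fold-specific scaler re-supplied.
    # The checkpoint path below is hypothetical.
    #
    #   model = LSTMForecastLightningModule.load_from_checkpoint(
    #       "checkpoints/fold_0/best.ckpt",     # hypothetical path
    #       target_scaler=fold_target_scaler,   # scaler fitted for that fold
    #   )
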
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the LSTM network.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, input_size)

        Returns:
            Predictions tensor of shape (batch_size, forecast_horizon)
        """
        # LSTM forward pass
        lstm_out, (hidden, cell) = self.lstm(x)  # lstm_out: (batch, seq_len, hidden_size)

        # Output from the last time step
        last_time_step_out = lstm_out[:, -1, :]  # Shape: (batch_size, hidden_size)

        # Apply dropout
        last_time_step_out = self.dropout(last_time_step_out)

        # Optional residual connection from the last input time step
        if self.use_residual_skips:
            residual = x[:, -1, :]  # Shape: (batch_size, input_size)
            if self.residual_projection is not None:
                residual = self.residual_projection(residual)  # Project to hidden_size
            last_time_step_out = last_time_step_out + residual

        # Final fully connected layer
        predictions = self.fc(last_time_step_out)  # Shape: (batch_size, forecast_horizon)

        return predictions

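    # Sketch (not part of the module): a quick shape check of forward() with
    # random data; the batch size and sequence length below are made up.
    #
    #   x = torch.randn(32, 168, model.hparams.input_size)
    #   assert model(x).shape == (32, model.output_size)
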
    def _calculate_loss(self, outputs, targets):
        # Ensure shapes match before loss calculation
        if outputs.shape != targets.shape:
            # Squeeze a potential extra dim: (batch, horizon, 1) -> (batch, horizon)
            if outputs.ndim == targets.ndim + 1 and outputs.shape[-1] == 1:
                outputs = outputs.squeeze(-1)
            if outputs.shape != targets.shape:
                raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
        return self.criterion(outputs, targets)

    def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
        """Helper to inverse transform data using the stored target scaler."""
        if self.target_scaler is None:
            return None  # Cannot inverse transform without a scaler

        # The sklearn scaler expects 2D input (N, 1); move data to CPU as float64
        data_cpu = data.detach().cpu().numpy().astype(np.float64)
        original_shape = data_cpu.shape
        if data_cpu.ndim not in (1, 2):  # Expected: (N,) or (batch, horizon)
            logger.warning(f"Unexpected shape for inverse transform: {original_shape}. Reshaping to (-1, 1).")
        data_flat = data_cpu.reshape(-1, 1)

        try:
            inversed_np = self.target_scaler.inverse_transform(data_flat)
            # Return a flattened tensor on the original device; the metrics below
            # consume flat values, so the original shape is not restored here
            # (use inversed_tensor.reshape(original_shape) if it is ever needed).
            inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
            return inversed_tensor.flatten()
        except Exception as e:
            logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
            return None  # Signal failure to the caller

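    # Sketch (not part of the module): _inverse_transform undoes the fold's
    # target scaling, e.g. a StandardScaler fitted on the training targets.
    # Variable names below are illustrative only.
    #
    #   scaler = StandardScaler().fit(train_targets.reshape(-1, 1))
    #   module = LSTMForecastLightningModule(model_cfg, train_cfg, input_size=8,
    #                                        target_scaler=scaler)
    #   y_scaled = torch.from_numpy(scaler.transform(train_targets.reshape(-1, 1))).float()
    #   y_back = module._inverse_transform(y_scaled)  # approximately the original targets, flattened
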
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch  # Shapes: x=(batch, seq_len, features), y=(batch, horizon)
        outputs = self(x)  # Scaled outputs: (batch, horizon)
        loss = self._calculate_loss(outputs, y)

        # Update and log scaled metrics
        self.train_metrics.update(outputs, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)  # Log the whole collection

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        outputs = self(x)  # Scaled outputs
        loss = self._calculate_loss(outputs, y)

        # Update and log scaled metrics
        self.val_metrics.update(outputs, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)

        # Log MAE on the ORIGINAL scale if a scaler is available
        # (often the primary metric for checkpointing/Optuna)
        if self.target_scaler is not None:
            outputs_inv = self._inverse_transform(outputs)
            y_inv = self._inverse_transform(y)

            if outputs_inv is not None and y_inv is not None:
                # Both tensors are flattened by _inverse_transform
                if outputs_inv.shape == y_inv.shape:
                    self.val_mae_original_scale.update(outputs_inv, y_inv)
                    self.log('val_mae_orig_scale', self.val_mae_original_scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
                else:
                    logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
            else:
                logger.warning("Could not compute original-scale MAE in validation because the inverse transform failed.")

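    # Sketch (not part of the module): 'val_mae_orig_scale' (or 'val_loss') is the
    # natural quantity to monitor in the Trainer's callbacks; the exact settings
    # below are illustrative, not taken from this repository.
    #
    #   from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    #   callbacks = [
    #       ModelCheckpoint(monitor='val_mae_orig_scale', mode='min'),
    #       EarlyStopping(monitor='val_mae_orig_scale', mode='min', patience=10),
    #   ]
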
    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        # Keep this method only if trainer.test() should log metrics; to obtain
        # predictions for evaluation, use predict_step instead. If
        # evaluate_fold_predictions handles all metrics, this can be simplified
        # further or removed. For now it only logs loss and the scaled metrics.
        try:
            x, y = batch
            outputs = self(x)
            loss = self._calculate_loss(outputs, y)
            # Update and log scaled test metrics so trainer.test() reports them
            self.test_metrics.update(outputs, y)
            self.log('test_loss_step', loss, on_step=True, on_epoch=False)  # Per-step loss, if needed
            self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True)
            # No return value needed when only logging
        except Exception as e:
            logger.error(f"Error occurred in test_step for batch {batch_idx}: {e}", exc_info=True)
            # Optionally log something here to indicate the failure

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Dict[str, torch.Tensor]:
        """
        Runs inference and returns scaled predictions (and targets, when available).

        The batch may be an (x, y) pair (e.g. when reusing the test loader for
        prediction) or features only, depending on the DataLoader used for predict.
        """
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            x, y = batch
        else:
            # The batch contains only features
            x = batch
            y = None  # No targets available

        outputs = self(x)  # Scaled outputs

        result = {'preds_scaled': outputs.detach().cpu()}
        if y is not None:
            # Include targets when they were part of the batch
            result['targets_scaled'] = y.detach().cpu()

        return result

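    # Sketch (not part of the module): assembling predict_step outputs into flat
    # tensors; trainer, model, and test_loader are assumed to exist in the caller.
    #
    #   batch_results = trainer.predict(model, dataloaders=test_loader)
    #   preds_scaled = torch.cat([r['preds_scaled'] for r in batch_results])
    #   targets_scaled = torch.cat([r['targets_scaled'] for r in batch_results])  # if targets were yielded
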
    def configure_optimizers(self) -> Union[optim.Optimizer, Tuple[List[optim.Optimizer], List[Dict[str, Any]]]]:
        """
        Configure the Adam optimizer and an optional StepLR scheduler.
        """
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.train_config.learning_rate  # Access the LR via hparams
        )
        logger.info(f"Configured Adam optimizer with LR: {self.hparams.train_config.learning_rate}")

        # Optional LR scheduler configuration
        scheduler_config = None
        if hasattr(self.hparams.train_config, 'scheduler_step_size') and \
           self.hparams.train_config.scheduler_step_size is not None and \
           hasattr(self.hparams.train_config, 'scheduler_gamma') and \
           self.hparams.train_config.scheduler_gamma is not None:

            if self.hparams.train_config.scheduler_step_size > 0 and 0 < self.hparams.train_config.scheduler_gamma < 1:
                logger.info(f"Configuring StepLR scheduler with step_size={self.hparams.train_config.scheduler_step_size} "
                            f"and gamma={self.hparams.train_config.scheduler_gamma}")
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.hparams.train_config.scheduler_step_size,
                    gamma=self.hparams.train_config.scheduler_gamma
                )
                scheduler_config = {
                    'scheduler': scheduler,
                    'interval': 'epoch',  # or 'step'
                    'frequency': 1,
                    'monitor': 'val_loss',  # Only relevant for metric-driven schedulers such as ReduceLROnPlateau
                }
            else:
                logger.warning("Scheduler parameters provided but invalid (step_size must be > 0 and 0 < gamma < 1). No scheduler configured.")

        # Example for ReduceLROnPlateau (if needed later):
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
        # scheduler_config = {'scheduler': scheduler, 'monitor': 'val_loss'}

        if scheduler_config:
            return [optimizer], [scheduler_config]
        else:
            return optimizer
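
# --- Example usage (sketch, not executed) ---
# How this module is typically wired into a Trainer. The config constructor
# arguments and the dataloaders below are assumptions for illustration; the real
# ModelConfig / TrainingConfig fields and the data pipeline live elsewhere in
# this project.
#
#   model_cfg = ModelConfig(hidden_size=64, num_layers=2, dropout=0.2,
#                           forecast_horizon=24, use_residual_skips=False)
#   train_cfg = TrainingConfig(learning_rate=1e-3, loss_function='MSE',
#                              scheduler_step_size=10, scheduler_gamma=0.5)
#   model = LSTMForecastLightningModule(model_cfg, train_cfg, input_size=8,
#                                       target_scaler=fold_scaler)
#   trainer = pl.Trainer(max_epochs=50)
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)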