""" Ensemble evaluation for time series forecasting models. This module provides functionality to evaluate ensemble predictions by combining predictions from n-1 folds and testing on the remaining fold. """ import logging import numpy as np import torch import yaml # For loading fold config from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from sklearn.preprocessing import StandardScaler, MinMaxScaler import pandas as pd # For time index handling import pickle # Need pickle for the specific error check from forecasting_model.evaluation import evaluate_fold_predictions from forecasting_model.train.model import LSTMForecastLightningModule from forecasting_model.utils.forecast_config_model import MainConfig logger = logging.getLogger(__name__) def load_fold_model_and_objects( fold_dir: Path, ) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]: """ Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons. Args: fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.). Returns: A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons) or None if any essential artifact is missing or loading fails. """ try: logger.info(f"Loading artifacts from: {fold_dir}") # 1. Load Fold Configuration config_path = fold_dir / "config.yaml" if not config_path.is_file(): logger.error(f"Fold config file not found in {fold_dir}") return None with open(config_path, 'r') as f: fold_config_dict = yaml.safe_load(f) fold_config = MainConfig(**fold_config_dict) # Validate fold's config # 2. 
        test_loader_path = fold_dir / "test_loader.pt"
        scaler_path = fold_dir / "target_scaler.pt"
        input_size_path = fold_dir / "input_size.pt"
        prediction_index_path = fold_dir / "prediction_target_time_index.pt"

        if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
            logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
            return None

        try:
            # --- Explicitly set weights_only=False for non-model objects ---
            test_loader = torch.load(test_loader_path, weights_only=False)
            target_scaler = torch.load(scaler_path, weights_only=False)
            input_size = torch.load(input_size_path, weights_only=False)
            # --- End Modification ---
        except pickle.UnpicklingError as e:
            # Catch potential unpickling errors even with weights_only=False
            logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
            return None
        except AttributeError as e:
            # Catch potential issues if class definitions changed between saving and loading
            logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
            return None
        except Exception as e:
            # Catch other potential loading errors
            logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
            return None

        # Retrieve forecast horizon list from the fold's config
        forecast_horizons = fold_config.features.forecast_horizon

        # --- Extract prediction target time index (if available) ---
        prediction_target_time_index: Optional[pd.Index] = None
        if prediction_index_path.is_file():
            try:
                prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
                # Basic validation
                if not isinstance(prediction_target_time_index, pd.Index):
                    logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
                    prediction_target_time_index = None
                else:
                    logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
            except Exception as e:
                logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
        else:
            logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
        # --- End Index Extraction ---

        # 3. Find Checkpoint and Load Model
        checkpoint_path = None
        try:
            # Use a recursive glob to find the checkpoint, which may be nested deeper
            checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
            if not checkpoints:
                logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
                return None
            if len(checkpoints) > 1:
                logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
            checkpoint_path = checkpoints[0]
            logger.info(f"Loading model from checkpoint: {checkpoint_path}")

            model = LSTMForecastLightningModule.load_from_checkpoint(
                checkpoint_path,
                map_location=torch.device('cpu'),  # Optional: load to CPU first if memory is tight
                model_config=fold_config.model,
                train_config=fold_config.training,
                input_size=input_size,
                target_scaler=target_scaler,
            )
            model.eval()
            logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
            return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons

        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {checkpoint_path}")
            return None
        except Exception as e:
            logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
            return None

    except Exception as e:
        logger.error(f"Generic error loading artifacts from {fold_dir}: {e}", exc_info=True)
        return None


def make_ensemble_predictions(
    models: List[LSTMForecastLightningModule],
    test_loader: torch.utils.data.DataLoader,
    device: Optional[torch.device] = None,
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
    """
    Make predictions using an ensemble of models efficiently.

    Processes the test_loader once, getting predictions from all models per batch.

    Args:
        models: List of trained models (already in eval mode).
        test_loader: DataLoader for the test set.
        device: Device to run predictions on (e.g., torch.device("cuda:0")).
                If None, attempts to use GPU if available, else CPU.

    Returns:
        Tuple of (ensemble_predictions, targets):
        - ensemble_predictions: Dict containing ensemble predictions keyed by method
          ('mean', 'median', 'min', 'max'). Values are np.arrays. Returns None if
          prediction fails.
        - targets: Ground truth values as a single np.array. Returns None if prediction
          fails or targets are unavailable in the loader.
    """
    if not models:
        logger.warning("make_ensemble_predictions received an empty list of models.")
        return None, None

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Running ensemble predictions on device: {device}")

    # Move all models to the target device
    for model in models:
        model.to(device)

    all_batch_preds: List[List[np.ndarray]] = [[] for _ in models]  # Outer list: models, inner list: batches
    all_batch_targets: List[np.ndarray] = []
    targets_available = True

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            try:
                # Determine if batch contains targets
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    x, y = batch
                    x = x.to(device)
                    # Keep targets on CPU until needed for concatenation
                    all_batch_targets.append(y.cpu().numpy())
                else:
                    x = batch.to(device)
                    targets_available = False  # No targets found in this batch

                # Get predictions from all models for this batch
                for i, model in enumerate(models):
                    try:
                        pred = model(x)  # Shape: (batch, horizon)
                        all_batch_preds[i].append(pred.cpu().numpy())
                    except Exception as model_err:
                        logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
                        # Handle error: Fill with NaNs? Skip model?
                        # For now, fill with NaNs of expected shape
                        # Infer expected shape: (batch_size, horizon)
                        batch_size = x.shape[0]
                        horizon = models[0].output_size  # Assume all models have same horizon
                        nan_preds = np.full((batch_size, horizon), np.nan)
                        all_batch_preds[i].append(nan_preds)

            except Exception as batch_err:
                logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
                # If a batch fails catastrophically, we might not be able to proceed reliably
                return None, None  # Indicate failure

    # Concatenate batch results for each model
    model_preds_concat = []
    for i in range(len(models)):
        if not all_batch_preds[i]:  # Check if any predictions were collected for this model
            logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
            continue  # Skip this model if it failed on all batches
        try:
            model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
        except ValueError as e:
            logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
            # Decide how to handle: skip model or fail? Let's skip for robustness.
            continue

    if not model_preds_concat:
        logger.error("No valid predictions collected from any model in the ensemble.")
        return None, None

    # Concatenate targets if available
    targets_concat = None
    if targets_available and all_batch_targets:
        try:
            targets_concat = np.concatenate(all_batch_targets, axis=0)
        except ValueError as e:
            logger.error(f"Failed to concatenate targets: {e}")
            return None, None  # Fail if targets were expected but couldn't be combined
    elif targets_available and not all_batch_targets:
        logger.warning("Targets were expected based on first batch, but none were collected.")
        # Proceed without targets, returning None for them

    # Stack predictions from all models: shape (num_models, num_samples, horizon)
    try:
        stacked_preds = np.stack(model_preds_concat, axis=0)
    except ValueError as e:
        logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
        return None, targets_concat  # Return targets if available, but no ensemble preds

    # Calculate different ensemble predictions (handle NaNs potentially introduced by model failures)
    # np.nanmean, np.nanmedian etc. ignore NaNs
    ensemble_preds = {
        'mean': np.nanmean(stacked_preds, axis=0),
        'median': np.nanmedian(stacked_preds, axis=0),
        'min': np.nanmin(stacked_preds, axis=0),
        'max': np.nanmax(stacked_preds, axis=0),
    }

    logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
    return ensemble_preds, targets_concat


def evaluate_ensemble_for_test_fold(
    test_fold_num: int,
    all_fold_dirs: List[Path],
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None  # Removed, get from loaded objects
) -> Optional[Dict[str, Dict[str, float]]]:
    """
    Evaluates ensemble predictions for a specific test fold.

    Args:
        test_fold_num: The 1-based number of the fold to use as the test set.
        all_fold_dirs: List of paths to all fold directories.
        output_base_dir: Base directory for saving evaluation results/plots.

    Returns:
        Dictionary containing metrics for each ensemble method for this test fold,
        or None if evaluation fails.
    """
    logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
    test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"

    load_result = load_fold_model_and_objects(test_fold_dir)
    if load_result is None:
        logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. "
                     f"Skipping ensemble evaluation for this fold.")
        return None

    # Unpack results including the prediction time index and horizons
    _, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result

    # Load models from all *other* folds
    ensemble_models: List[LSTMForecastLightningModule] = []
    model_forecast_horizons = None  # Track horizons from loaded models
    for i, fold_dir in enumerate(all_fold_dirs):
        current_fold_num = i + 1
        if current_fold_num == test_fold_num:
            continue  # Skip the test fold itself

        model_load_result = load_fold_model_and_objects(fold_dir)
        if model_load_result:
            model, _, _, _, _, _, fold_horizons = model_load_result  # Only need the model here
            if model:
                ensemble_models.append(model)
                # Store horizons from the first successful model load
                if model_forecast_horizons is None:
                    model_forecast_horizons = fold_horizons
                # Optional: Check consistency of horizons across ensemble models
                elif set(model_forecast_horizons) != set(fold_horizons):
                    logger.error(f"Inconsistent forecast horizons between ensemble models! Test fold {test_fold_num} expected {test_forecast_horizons}, "
                                 f"Model {i+1} has {fold_horizons}. Ensemble may be invalid.")
                    # Decide how to handle: error out, or proceed with caution?
                    # return None  # Option: Fail hard
        else:
            logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")

    if len(ensemble_models) < 2:
        logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
                       f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
        return {}  # Return empty dict, not None, to indicate process ran but no ensemble formed

    # Check consistency between test fold horizons and ensemble model horizons
    if model_forecast_horizons is None:  # Should not happen if len(ensemble_models) >= 1
        logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
        return None
    if set(test_forecast_horizons) != set(model_forecast_horizons):
        logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
                     f"horizons from ensemble models ({model_forecast_horizons}). "
                     f"Cannot evaluate.")
        return None

    # Make ensemble predictions using the loaded models and the test fold's data loader
    # Determine the prediction device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)

    if ensemble_preds_dict is None or targets_np is None:
        logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
        return None  # Indicate failure

    # Evaluate each ensemble method's predictions against the test fold's targets
    fold_ensemble_results: Dict[str, Dict[str, float]] = {}
    for method, preds_np in ensemble_preds_dict.items():
        logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")

        # Define a unique output directory for this method's plots
        method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"

        # Use the prediction_target_time_index loaded earlier
        prediction_time_index_for_plot = None
        if prediction_target_time_index is not None:
            if len(prediction_target_time_index) == targets_np.shape[0]:
                prediction_time_index_for_plot = prediction_target_time_index
            else:
                logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
                               f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")

        # Call the standard evaluation function
        metrics = evaluate_fold_predictions(
            y_true_scaled=targets_np,
            y_pred_scaled=preds_np,
            target_scaler=target_scaler,
            eval_config=test_fold_config.evaluation,
            fold_num=test_fold_num - 1,
            output_dir=str(method_plot_dir.parent.parent),
            plot_subdir=f"method_{method}",
            prediction_time_index=prediction_time_index_for_plot,  # Pass the index
            forecast_horizons=test_forecast_horizons,
            plot_title_prefix=f"Ensemble ({method})",
        )
        fold_ensemble_results[method] = metrics

    logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
    return fold_ensemble_results


def run_ensemble_evaluation(
    config: MainConfig,  # Pass main config for context if needed, though fold configs are loaded
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None  # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
    """
    Run ensemble evaluation across all folds, treating each as the test set once.

    Args:
        config: The main configuration object (potentially unused if the fold configs are sufficient).
        output_base_dir: Base directory where fold outputs are stored.

    Returns:
        Dictionary containing ensemble metrics for each test fold:
        { test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
    """
    logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
    all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}

    # Discover fold directories
    fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
    if not fold_dirs:
        logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
        return {}
    if len(fold_dirs) < 2:
        logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. "
Skipping.") return {} logger.info(f"Found {len(fold_dirs)} fold directories.") # Iterate through each fold, designating it as the test fold for i, test_fold_dir in enumerate(fold_dirs): test_fold_num = i + 1 # 1-based fold number try: results_for_test_fold = evaluate_ensemble_for_test_fold( test_fold_num=test_fold_num, all_fold_dirs=fold_dirs, output_base_dir=output_base_dir, # full_data_index=full_data_index # Removed ) if results_for_test_fold is not None: # Only add results if the evaluation didn't fail completely all_ensemble_results[test_fold_num] = results_for_test_fold except Exception as e: # Catch unexpected errors during a specific test fold evaluation logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True) continue # Continue to the next fold # Saving is handled by the main script (`forecasting_model_run.py`) which calls this if not all_ensemble_results: logger.warning("Ensemble evaluation finished, but no results were generated.") else: logger.info("===== Finished Cross-Validated Ensemble Evaluation =====") return all_ensemble_results