intermediate backup

forecasting_model/train/ensemble_evaluation.py (new file, 425 lines)

@@ -0,0 +1,425 @@
"""
Ensemble evaluation for time series forecasting models.

This module provides functionality to evaluate ensemble predictions
by combining predictions from n-1 folds and testing on the remaining fold.
"""

import logging
import numpy as np
import torch
import yaml  # For loading fold config
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd  # For time index handling
import pickle  # Need pickle for the specific error check

from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig

logger = logging.getLogger(__name__)


def load_fold_model_and_objects(
    fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
    """
    Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.

    Args:
        fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).

    Returns:
        A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
        or None if any essential artifact is missing or loading fails.
    """
    try:
        logger.info(f"Loading artifacts from: {fold_dir}")

        # 1. Load Fold Configuration
        config_path = fold_dir / "config.yaml"
        if not config_path.is_file():
            logger.error(f"Fold config file not found in {fold_dir}")
            return None
        with open(config_path, 'r') as f:
            fold_config_dict = yaml.safe_load(f)
        fold_config = MainConfig(**fold_config_dict)  # Validate fold's config

        # 2. Load Saved Objects using torch.load
        test_loader_path = fold_dir / "test_loader.pt"
        scaler_path = fold_dir / "target_scaler.pt"
        input_size_path = fold_dir / "input_size.pt"
        prediction_index_path = fold_dir / "prediction_target_time_index.pt"

        if not all(p.is_file() for p in [test_loader_path, scaler_path, input_size_path]):
            logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
            return None

        try:
            # --- Explicitly set weights_only=False for non-model objects ---
            test_loader = torch.load(test_loader_path, weights_only=False)
            target_scaler = torch.load(scaler_path, weights_only=False)
            input_size = torch.load(input_size_path, weights_only=False)
            # --- End Modification ---
        except pickle.UnpicklingError as e:
            # Catch potential unpickling errors even with weights_only=False
            logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
            return None
        except AttributeError as e:
            # Catch potential issues if class definitions changed between saving and loading
            logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
            return None
        except Exception as e:
            # Catch other potential loading errors
            logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
            return None

        # Retrieve forecast horizon list from the fold's config
        forecast_horizons = fold_config.features.forecast_horizon

        # --- Extract prediction target time index (if available) ---
        prediction_target_time_index: Optional[pd.Index] = None
        if prediction_index_path.is_file():
            try:
                prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
                # Basic validation
                if not isinstance(prediction_target_time_index, pd.Index):
                    logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
                    prediction_target_time_index = None
                else:
                    logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
            except Exception as e:
                logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
        else:
            logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
        # --- End Index Extraction ---

        # 3. Find Checkpoint and Load Model
        checkpoint_path = None
        try:
            # Search recursively, since the checkpoint may be nested in a subdirectory
            checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
            if not checkpoints:
                logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
                return None
            if len(checkpoints) > 1:
                logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
            checkpoint_path = checkpoints[0]

            logger.info(f"Loading model from checkpoint: {checkpoint_path}")
            model = LSTMForecastLightningModule.load_from_checkpoint(
                checkpoint_path,
                map_location=torch.device('cpu'),  # Optional: load to CPU first if memory is tight
                model_config=fold_config.model,
                train_config=fold_config.training,
                input_size=input_size,
                target_scaler=target_scaler
            )
            model.eval()
            logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
            return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons

        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {checkpoint_path}")
            return None
        except Exception as e:
            logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
            return None

    except Exception as e:
        logger.error(f"Unexpected error loading artifacts from {fold_dir}: {e}", exc_info=True)
        return None


def make_ensemble_predictions(
    models: List[LSTMForecastLightningModule],
    test_loader: torch.utils.data.DataLoader,
    device: Optional[torch.device] = None
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
    """
    Make predictions using an ensemble of models efficiently.

    Processes the test_loader once, getting predictions from all models per batch.

    Args:
        models: List of trained models (already in eval mode).
        test_loader: DataLoader for the test set.
        device: Device to run predictions on (e.g., torch.device("cuda:0")).
            If None, attempts to use GPU if available, else CPU.

    Returns:
        Tuple of (ensemble_predictions, targets):
        - ensemble_predictions: Dict containing ensemble predictions keyed by method
          ('mean', 'median', 'min', 'max'). Values are np.ndarrays.
          Returns None if prediction fails.
        - targets: Ground truth values as a single np.ndarray. Returns None if prediction fails
          or targets are unavailable in the loader.
    """
    if not models:
        logger.warning("make_ensemble_predictions received an empty list of models.")
        return None, None

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Running ensemble predictions on device: {device}")

    # Move all models to the target device
    for model in models:
        model.to(device)

    all_batch_preds: List[List[np.ndarray]] = [[] for _ in models]  # Outer list: models, inner list: batches
    all_batch_targets: List[np.ndarray] = []
    targets_available = True

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            try:
                # Determine if the batch contains targets
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    x, y = batch
                    x = x.to(device)
                    # Keep targets on CPU until needed for concatenation
                    all_batch_targets.append(y.cpu().numpy())
                else:
                    x = batch.to(device)
                    targets_available = False  # No targets found in this batch

                # Get predictions from all models for this batch
                for i, model in enumerate(models):
                    try:
                        pred = model(x)  # Shape: (batch, horizon)
                        all_batch_preds[i].append(pred.cpu().numpy())
                    except Exception as model_err:
                        logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
                        # Fill with NaNs of the expected shape (batch_size, horizon) so this
                        # model is ignored by the NaN-aware aggregation below.
                        batch_size = x.shape[0]
                        horizon = models[0].output_size  # Assume all models have the same horizon
                        nan_preds = np.full((batch_size, horizon), np.nan)
                        all_batch_preds[i].append(nan_preds)

            except Exception as batch_err:
                logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
                # If a batch fails catastrophically, we cannot proceed reliably
                return None, None  # Indicate failure

    # Concatenate batch results for each model
    model_preds_concat = []
    for i in range(len(models)):
        if not all_batch_preds[i]:  # Check if any predictions were collected for this model
            logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
            continue  # Skip this model if it failed on all batches
        try:
            model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
        except ValueError as e:
            logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
            # Skip this model rather than failing the whole ensemble
            continue

    if not model_preds_concat:
        logger.error("No valid predictions collected from any model in the ensemble.")
        return None, None

    # Concatenate targets if available
    targets_concat = None
    if targets_available and all_batch_targets:
        try:
            targets_concat = np.concatenate(all_batch_targets, axis=0)
        except ValueError as e:
            logger.error(f"Failed to concatenate targets: {e}")
            return None, None  # Fail if targets were expected but could not be combined
    elif targets_available and not all_batch_targets:
        logger.warning("Targets were expected based on the first batch, but none were collected.")
        # Proceed without targets, returning None for them

    # Stack predictions from all models: shape (num_models, num_samples, horizon)
    try:
        stacked_preds = np.stack(model_preds_concat, axis=0)
    except ValueError as e:
        logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
        return None, targets_concat  # Return targets if available, but no ensemble preds

    # Calculate the different ensemble predictions; np.nanmean, np.nanmedian etc.
    # ignore any NaNs introduced by individual model failures.
    ensemble_preds = {
        'mean': np.nanmean(stacked_preds, axis=0),
        'median': np.nanmedian(stacked_preds, axis=0),
        'min': np.nanmin(stacked_preds, axis=0),
        'max': np.nanmax(stacked_preds, axis=0)
    }

    logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
    return ensemble_preds, targets_concat


def evaluate_ensemble_for_test_fold(
    test_fold_num: int,
    all_fold_dirs: List[Path],
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None  # Removed, get from loaded objects
) -> Optional[Dict[str, Dict[str, float]]]:
    """
    Evaluate ensemble predictions for a specific test fold.

    Args:
        test_fold_num: The 1-based number of the fold to use as the test set.
        all_fold_dirs: List of paths to all fold directories.
        output_base_dir: Base directory for saving evaluation results/plots.

    Returns:
        Dictionary containing metrics for each ensemble method for this test fold,
        or None if evaluation fails.
    """
    logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
    test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"

    load_result = load_fold_model_and_objects(test_fold_dir)
    if load_result is None:
        logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
        return None
    # Unpack results, including the prediction time index and horizons
    _, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result

    # Load models from all *other* folds
    ensemble_models: List[LSTMForecastLightningModule] = []
    model_forecast_horizons = None  # Track horizons from loaded models
    for i, fold_dir in enumerate(all_fold_dirs):
        current_fold_num = i + 1
        if current_fold_num == test_fold_num:
            continue  # Skip the test fold itself

        model_load_result = load_fold_model_and_objects(fold_dir)
        if model_load_result:
            model, _, _, _, _, _, fold_horizons = model_load_result  # Only need the model here
            if model:
                ensemble_models.append(model)
                # Store horizons from the first successful model load
                if model_forecast_horizons is None:
                    model_forecast_horizons = fold_horizons
                # Optional: Check consistency of horizons across ensemble models
                elif set(model_forecast_horizons) != set(fold_horizons):
                    logger.error(f"Inconsistent forecast horizons between ensemble models! Test fold {test_fold_num} expected {test_forecast_horizons}, "
                                 f"Model {i+1} has {fold_horizons}. Ensemble may be invalid.")
                    # Decide how to handle: error out, or proceed with caution?
                    # return None  # Option: fail hard
        else:
            logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")

    if len(ensemble_models) < 2:
        logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
                       f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
        return {}  # Return empty dict, not None, to indicate the process ran but no ensemble was formed

    # Check consistency between test fold horizons and ensemble model horizons
    if model_forecast_horizons is None:  # Should not happen if len(ensemble_models) >= 1
        logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
        return None
    if set(test_forecast_horizons) != set(model_forecast_horizons):
        logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
                     f"horizons from ensemble models ({model_forecast_horizons}). Cannot evaluate.")
        return None

    # Make ensemble predictions using the loaded models and the test fold's data loader
    # Select device: use GPU if available, else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)

    if ensemble_preds_dict is None or targets_np is None:
        logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
        return None  # Indicate failure

    # Evaluate each ensemble method's predictions against the test fold's targets
    fold_ensemble_results: Dict[str, Dict[str, float]] = {}
    for method, preds_np in ensemble_preds_dict.items():
        logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")

        # Define a unique output directory for this method's plots
        method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"

        # Use the prediction_target_time_index loaded earlier
        prediction_time_index_for_plot = None
        if prediction_target_time_index is not None:
            if len(prediction_target_time_index) == targets_np.shape[0]:
                prediction_time_index_for_plot = prediction_target_time_index
            else:
                logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
                               f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")

        # Call the standard evaluation function
        metrics = evaluate_fold_predictions(
            y_true_scaled=targets_np,
            y_pred_scaled=preds_np,
            target_scaler=target_scaler,
            eval_config=test_fold_config.evaluation,
            fold_num=test_fold_num - 1,
            output_dir=str(method_plot_dir.parent.parent),
            plot_subdir=f"method_{method}",
            prediction_time_index=prediction_time_index_for_plot,  # Pass the index
            forecast_horizons=test_forecast_horizons,
            plot_title_prefix=f"Ensemble ({method})"
        )
        fold_ensemble_results[method] = metrics

    logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
    return fold_ensemble_results


def run_ensemble_evaluation(
    config: MainConfig,  # Pass main config for context if needed, though fold configs are loaded
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None  # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
    """
    Run ensemble evaluation across all folds, treating each as the test set once.

    Args:
        config: The main configuration object (potentially unused if fold configs are sufficient).
        output_base_dir: Base directory where fold outputs are stored.

    Returns:
        Dictionary containing ensemble metrics for each test fold:
        { test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
    """
    logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
    all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}

    # Discover fold directories
    fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
    if not fold_dirs:
        logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
        return {}
    if len(fold_dirs) < 2:
        logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. Skipping.")
        return {}

    logger.info(f"Found {len(fold_dirs)} fold directories.")

    # Iterate through each fold, designating it as the test fold
    for i, test_fold_dir in enumerate(fold_dirs):
        test_fold_num = i + 1  # 1-based fold number
        try:
            results_for_test_fold = evaluate_ensemble_for_test_fold(
                test_fold_num=test_fold_num,
                all_fold_dirs=fold_dirs,
                output_base_dir=output_base_dir,
                # full_data_index=full_data_index  # Removed
            )

            if results_for_test_fold is not None:
                # Only add results if the evaluation didn't fail completely
                all_ensemble_results[test_fold_num] = results_for_test_fold

        except Exception as e:
            # Catch unexpected errors during a specific test fold evaluation
            logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True)
            continue  # Continue to the next fold

    # Saving is handled by the main script (`forecasting_model_run.py`) which calls this
    if not all_ensemble_results:
        logger.warning("Ensemble evaluation finished, but no results were generated.")
    else:
        logger.info("===== Finished Cross-Validated Ensemble Evaluation =====")

    return all_ensemble_results
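

if __name__ == "__main__":
    # Minimal usage sketch (illustration only). In the real project this module
    # is driven by `forecasting_model_run.py`; the command-line argument and the
    # assumption that a validated config.yaml sits in the output directory are
    # hypothetical and shown here purely to demonstrate the entry point.
    import argparse

    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Run cross-validated ensemble evaluation.")
    parser.add_argument("output_dir", type=Path, help="Directory containing the fold_XX subdirectories.")
    args = parser.parse_args()

    # Assumed location of the run-level config (same schema as the per-fold configs).
    with open(args.output_dir / "config.yaml", "r") as f:
        main_config = MainConfig(**yaml.safe_load(f))

    results = run_ensemble_evaluation(config=main_config, output_base_dir=args.output_dir)
    for fold_num, methods in results.items():
        logger.info(f"Test fold {fold_num}: evaluated ensemble methods {sorted(methods.keys())}")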