intermediate backup

2025-05-03 20:46:14 +02:00
parent 2b0a5728d4
commit 6542caf48f
38 changed files with 4513 additions and 1067 deletions

@@ -0,0 +1,425 @@
"""
Ensemble evaluation for time series forecasting models.
This module provides functionality to evaluate ensemble predictions
by combining the predictions of the models trained on n-1 folds and testing them on the remaining fold.
"""
import logging
import numpy as np
import torch
import yaml # For loading fold config
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd # For time index handling
import pickle # Need pickle for the specific error check
from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig
logger = logging.getLogger(__name__)
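
# Illustration of the leave-one-fold-out ensemble scheme implemented below
# (a sketch assuming a 5-fold run; the actual number of folds is discovered at runtime):
#
#   test fold 1  <- ensemble of models from folds 2, 3, 4, 5
#   test fold 2  <- ensemble of models from folds 1, 3, 4, 5
#   ...
#   test fold 5  <- ensemble of models from folds 1, 2, 3, 4
#
# Per-sample predictions are aggregated with mean / median / min / max.
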
def load_fold_model_and_objects(
fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
"""
Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.
Args:
fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).
Returns:
A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
or None if any essential artifact is missing or loading fails.
"""
try:
logger.info(f"Loading artifacts from: {fold_dir}")
# 1. Load Fold Configuration
config_path = fold_dir / "config.yaml"
if not config_path.is_file():
logger.error(f"Fold config file not found in {fold_dir}")
return None
with open(config_path, 'r') as f:
fold_config_dict = yaml.safe_load(f)
fold_config = MainConfig(**fold_config_dict) # Validate fold's config
# 2. Load Saved Objects using torch.load
test_loader_path = fold_dir / "test_loader.pt"
scaler_path = fold_dir / "target_scaler.pt"
input_size_path = fold_dir / "input_size.pt"
prediction_index_path = fold_dir / "prediction_target_time_index.pt"
if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
return None
try:
# --- Explicitly set weights_only=False for non-model objects ---
test_loader = torch.load(test_loader_path, weights_only=False)
target_scaler = torch.load(scaler_path, weights_only=False)
input_size = torch.load(input_size_path, weights_only=False)
# --- End Modification ---
except pickle.UnpicklingError as e:
# Catch potential unpickling errors even with weights_only=False
logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
return None
except AttributeError as e:
# Catch potential issues if class definitions changed between saving and loading
logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
return None
except Exception as e:
# Catch other potential loading errors
logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
return None
# Retrieve forecast horizon list from the fold's config
forecast_horizons = fold_config.features.forecast_horizon
# --- Extract prediction target time index (if available) ---
prediction_target_time_index: Optional[pd.Index] = None
if prediction_index_path.is_file():
try:
prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
# Basic validation
if not isinstance(prediction_target_time_index, pd.Index):
logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
prediction_target_time_index = None
else:
logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
except Exception as e:
logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
else:
logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
# --- End Index Extraction ---
# 3. Find Checkpoint and Load Model
checkpoint_path = None
try:
# Search recursively, as the checkpoint may be nested in a subdirectory
checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
if not checkpoints:
logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
return None
if len(checkpoints) > 1:
logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
checkpoint_path = checkpoints[0]
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
model = LSTMForecastLightningModule.load_from_checkpoint(
checkpoint_path,
map_location=torch.device('cpu'), # Optional: load to CPU first if memory is tight
model_config=fold_config.model,
train_config=fold_config.training,
input_size=input_size,
target_scaler=target_scaler
)
model.eval()
logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons
except FileNotFoundError:
logger.error(f"Checkpoint file not found: {checkpoint_path}")
return None
except Exception as e:
logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
return None
except Exception as e:
logger.error(f"Generic error loading artifacts from {fold_dir}: {e}", exc_info=True)
return None
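
# Expected fold directory layout and a usage sketch for the loader above
# (illustrative; the "output/fold_01" path is an assumption, not a fixed convention):
#
#   output/fold_01/
#       config.yaml
#       test_loader.pt
#       target_scaler.pt
#       input_size.pt
#       prediction_target_time_index.pt      (optional)
#       .../best_model_fold_*.ckpt           (found via recursive glob)
#
#   artifacts = load_fold_model_and_objects(Path("output/fold_01"))
#   if artifacts is not None:
#       model, cfg, loader, scaler, input_size, pred_index, horizons = artifacts
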
def make_ensemble_predictions(
models: List[LSTMForecastLightningModule],
test_loader: torch.utils.data.DataLoader,
device: Optional[torch.device] = None
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
"""
Make predictions using an ensemble of models efficiently.
Processes the test_loader once, getting predictions from all models per batch.
Args:
models: List of trained models (already in eval mode).
test_loader: DataLoader for the test set.
device: Device to run predictions on (e.g., torch.device("cuda:0")).
If None, attempts to use GPU if available, else CPU.
Returns:
Tuple of (ensemble_predictions, targets):
- ensemble_predictions: Dict containing ensemble predictions keyed by method
('mean', 'median', 'min', 'max'). Values are np.arrays.
Returns None if prediction fails.
- targets: Ground truth values as a single np.array. Returns None if prediction fails
or targets are unavailable in loader.
"""
if not models:
logger.warning("make_ensemble_predictions received an empty list of models.")
return None, None
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Running ensemble predictions on device: {device}")
# Move all models to the target device
for model in models:
model.to(device)
all_batch_preds: List[List[np.ndarray]] = [[] for _ in models] # Outer list: models, Inner list: batches
all_batch_targets: List[np.ndarray] = []
targets_available = True
with torch.no_grad():
for batch_idx, batch in enumerate(test_loader):
try:
# Determine if batch contains targets
if isinstance(batch, (list, tuple)) and len(batch) == 2:
x, y = batch
x = x.to(device)
# Keep targets on CPU until needed for concatenation
all_batch_targets.append(y.cpu().numpy())
else:
x = batch.to(device)
targets_available = False # No targets found in this batch
# Get predictions from all models for this batch
for i, model in enumerate(models):
try:
pred = model(x) # Shape: (batch, horizon)
all_batch_preds[i].append(pred.cpu().numpy())
except Exception as model_err:
logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
# Fall back to NaNs of the expected shape (batch_size, horizon) so the
# nan-aware aggregation below can ignore this model's failed batch
batch_size = x.shape[0]
horizon = models[0].output_size # Assume all models have same horizon
nan_preds = np.full((batch_size, horizon), np.nan)
all_batch_preds[i].append(nan_preds)
except Exception as batch_err:
logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
# If a batch fails catastrophically, we might not be able to proceed reliably
return None, None # Indicate failure
# Concatenate batch results for each model
model_preds_concat = []
for i in range(len(models)):
if not all_batch_preds[i]: # Check if any predictions were collected for this model
logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
continue # Skip this model if it failed on all batches
try:
model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
except ValueError as e:
logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
# Decide how to handle: skip model or fail? Let's skip for robustness.
continue
if not model_preds_concat:
logger.error("No valid predictions collected from any model in the ensemble.")
return None, None
# Concatenate targets if available
targets_concat = None
if targets_available and all_batch_targets:
try:
targets_concat = np.concatenate(all_batch_targets, axis=0)
except ValueError as e:
logger.error(f"Failed to concatenate targets: {e}")
return None, None # Fail if targets were expected but couldn't be combined
elif targets_available and not all_batch_targets:
logger.warning("Targets were expected based on first batch, but none were collected.")
# Proceed without targets, returning None for them
# Stack predictions from all models: Shape (num_models, num_samples, horizon)
try:
stacked_preds = np.stack(model_preds_concat, axis=0)
except ValueError as e:
logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
return None, targets_concat # Return targets if available, but no ensemble preds
# Calculate different ensemble predictions (handle NaNs potentially introduced by model failures)
# np.nanmean, np.nanmedian etc. ignore NaNs
ensemble_preds = {
'mean': np.nanmean(stacked_preds, axis=0),
'median': np.nanmedian(stacked_preds, axis=0),
'min': np.nanmin(stacked_preds, axis=0),
'max': np.nanmax(stacked_preds, axis=0)
}
logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
return ensemble_preds, targets_concat
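
# Minimal illustration of the aggregation step above on a toy stacked array
# (2 models, 3 samples, 1 horizon step); a NaN left by a failed model is
# ignored by the nan-aware reducers:
#
#   toy = np.stack([np.array([[1.0], [2.0], [np.nan]]),
#                   np.array([[3.0], [4.0], [5.0]])], axis=0)   # shape (2, 3, 1)
#   np.nanmean(toy, axis=0)    # -> [[2.0], [3.0], [5.0]]
#   np.nanmedian(toy, axis=0)  # -> [[2.0], [3.0], [5.0]]
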
def evaluate_ensemble_for_test_fold(
test_fold_num: int,
all_fold_dirs: List[Path],
output_base_dir: Path,
# full_data_index: Optional[pd.Index] = None # Removed, get from loaded objects
) -> Optional[Dict[str, Dict[str, float]]]:
"""
Evaluates ensemble predictions for a specific test fold.
Args:
test_fold_num: The 1-based number of the fold to use as the test set.
all_fold_dirs: List of paths to all fold directories.
output_base_dir: Base directory for saving evaluation results/plots.
Returns:
Dictionary containing metrics for each ensemble method for this test fold,
or None if evaluation fails.
"""
logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"
load_result = load_fold_model_and_objects(test_fold_dir)
if load_result is None:
logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
return None
# Unpack results including the prediction time index and horizons
_, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result
# Load models from all *other* folds
ensemble_models: List[LSTMForecastLightningModule] = []
model_forecast_horizons = None # Track horizons from loaded models
for i, fold_dir in enumerate(all_fold_dirs):
current_fold_num = i + 1
if current_fold_num == test_fold_num:
continue # Skip the test fold itself
model_load_result = load_fold_model_and_objects(fold_dir)
if model_load_result:
model, _, _, _, _, _, fold_horizons = model_load_result # Only need the model and its horizons here
if model:
ensemble_models.append(model)
# Store horizons from the first successful model load
if model_forecast_horizons is None:
model_forecast_horizons = fold_horizons
# Optional: Check consistency of horizons across ensemble models
elif set(model_forecast_horizons) != set(fold_horizons):
logger.error(f"Inconsistent forecast horizons between ensemble models! Previously loaded models have {model_forecast_horizons}, "
f"but the model from fold {current_fold_num} has {fold_horizons}. Ensemble may be invalid.")
# Decide how to handle: error out, or proceed with caution?
# return None # Option: Fail hard
else:
logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")
if len(ensemble_models) < 2:
logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
return {} # Return empty dict, not None, to indicate process ran but no ensemble formed
# Check consistency between test fold horizons and ensemble model horizons
if model_forecast_horizons is None: # Should not happen if len(ensemble_models) >= 1
logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
return None
if set(test_forecast_horizons) != set(model_forecast_horizons):
logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
f"horizons from ensemble models ({model_forecast_horizons}). Cannot evaluate.")
return None
# Make ensemble predictions using the loaded models and the test fold's data loader
# Select the prediction device based on CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)
if ensemble_preds_dict is None or targets_np is None:
logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
return None # Indicate failure
# Evaluate each ensemble method's predictions against the test fold's targets
fold_ensemble_results: Dict[str, Dict[str, float]] = {}
for method, preds_np in ensemble_preds_dict.items():
logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")
# Define a unique output directory for this method's plots
method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"
# Use the prediction_target_time_index loaded earlier
prediction_time_index_for_plot = None
if prediction_target_time_index is not None:
if len(prediction_target_time_index) == targets_np.shape[0]:
prediction_time_index_for_plot = prediction_target_time_index
else:
logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")
# Call the standard evaluation function
metrics = evaluate_fold_predictions(
y_true_scaled=targets_np,
y_pred_scaled=preds_np,
target_scaler=target_scaler,
eval_config=test_fold_config.evaluation,
fold_num=test_fold_num - 1,
output_dir=str(method_plot_dir.parent.parent),
plot_subdir=f"method_{method}",
prediction_time_index=prediction_time_index_for_plot, # Pass the index
forecast_horizons=test_forecast_horizons,
plot_title_prefix=f"Ensemble ({method})"
)
fold_ensemble_results[method] = metrics
logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
return fold_ensemble_results
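
# Shape of the value returned above (metric names depend on evaluate_fold_predictions;
# "MAE"/"RMSE" are shown purely as placeholders):
#
#   {
#       'mean':   {'MAE': ..., 'RMSE': ...},
#       'median': {'MAE': ..., 'RMSE': ...},
#       'min':    {'MAE': ..., 'RMSE': ...},
#       'max':    {'MAE': ..., 'RMSE': ...},
#   }
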
def run_ensemble_evaluation(
config: MainConfig, # Pass main config for context if needed, though fold configs are loaded
output_base_dir: Path,
# full_data_index: Optional[pd.Index] = None # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
"""
Run ensemble evaluation across all folds, treating each as the test set once.
Args:
config: The main configuration object (potentially unused if the per-fold configs are sufficient).
output_base_dir: Base directory where fold outputs are stored.
Returns:
Dictionary containing ensemble metrics for each test fold:
{ test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
"""
logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}
# Discover fold directories
fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
if not fold_dirs:
logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
return {}
if len(fold_dirs) < 2:
logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. Skipping.")
return {}
logger.info(f"Found {len(fold_dirs)} fold directories.")
# Iterate through each fold, designating it as the test fold
for i, test_fold_dir in enumerate(fold_dirs):
test_fold_num = i + 1 # 1-based fold number
try:
results_for_test_fold = evaluate_ensemble_for_test_fold(
test_fold_num=test_fold_num,
all_fold_dirs=fold_dirs,
output_base_dir=output_base_dir,
# full_data_index=full_data_index # Removed
)
if results_for_test_fold is not None:
# Only add results if the evaluation didn't fail completely
all_ensemble_results[test_fold_num] = results_for_test_fold
except Exception as e:
# Catch unexpected errors during a specific test fold evaluation
logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True)
continue # Continue to the next fold
# Saving is handled by the main script (`forecasting_model_run.py`) which calls this
if not all_ensemble_results:
logger.warning("Ensemble evaluation finished, but no results were generated.")
else:
logger.info("===== Finished Cross-Validated Ensemble Evaluation =====")
return all_ensemble_results
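
# Standalone usage sketch. In the real pipeline this module is driven by
# `forecasting_model_run.py` (see the comment in run_ensemble_evaluation); the config
# path and output directory below are placeholders, not project conventions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    example_config_path = Path("config.yaml")      # placeholder path
    example_output_dir = Path("output/cv_run")     # placeholder path containing fold_* subdirectories
    with open(example_config_path, 'r') as f:
        main_config = MainConfig(**yaml.safe_load(f))
    results = run_ensemble_evaluation(config=main_config, output_base_dir=example_output_dir)
    for fold_num, method_metrics in results.items():
        logger.info(f"Test fold {fold_num}: {method_metrics}")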