# File: entrix_case_challange/forecasting_model/evaluation.py
# Snapshot: 2025-05-02 14:36:19 +02:00 (325 lines, 14 KiB, Python)

import logging
import os
from pathlib import Path # Added
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler # For type hinting target_scaler
from typing import Dict, Any, Optional, Union, List, Tuple
# import matplotlib.pyplot as plt # No longer needed directly
# import seaborn as sns # No longer needed directly
# Assuming config_model and io.plotting are accessible
from forecasting_model.utils.config_model import EvaluationConfig
from forecasting_model.io.plotting import ( # Import the plotting utilities
setup_plot_style,
save_plot,
create_time_series_plot,
create_scatter_plot,
create_residuals_plot,
create_residuals_distribution_plot
)
logger = logging.getLogger(__name__)
# --- Metric Calculations (Utilities - Optional) ---
# Standalone NumPy helpers (calculate_mae_np, calculate_rmse_np), kept for use
# outside the torchmetrics-based evaluation path below.
def calculate_mae_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Mean Absolute Error using NumPy.

    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (typically flattened). Any array-like
            (NumPy array, list, ...) is accepted.
        y_pred: Predicted values, same shape as ``y_true``.

    Returns:
        The mean of ``|y_true - y_pred|`` over all elements as a Python float,
        or NaN if the shapes differ or the inputs are empty. NaN values in the
        inputs propagate to the result instead of raising.
    """
    # Accept array-likes; float dtype also makes bool/int inputs subtract safely.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for MAE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    # `.size` (not len()) also catches 0-d scalars' edge cases and (n, 0) arrays.
    if y_true.size == 0:
        logger.warning("Attempting to calculate MAE on empty arrays.")
        return np.nan
    # Plain NumPy: sklearn's input validation raises on NaN/inf, which would
    # break the "NaN on invalid input" contract documented above.
    return float(np.mean(np.abs(y_true - y_pred)))
def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Root Mean Squared Error using NumPy.

    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (typically flattened). Any array-like
            (NumPy array, list, ...) is accepted.
        y_pred: Predicted values, same shape as ``y_true``.

    Returns:
        ``sqrt(mean((y_true - y_pred) ** 2))`` over all elements as a Python
        float, or NaN if the shapes differ or the inputs are empty. NaN values
        in the inputs propagate to the result instead of raising.
    """
    # Accept array-likes; float dtype also makes bool/int inputs subtract safely.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for RMSE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    # `.size` (not len()) also catches (n, 0) arrays.
    if y_true.size == 0:
        logger.warning("Attempting to calculate RMSE on empty arrays.")
        return np.nan
    # Plain NumPy instead of sklearn: the previous call passed `squared=True`,
    # a keyword removed from sklearn's mean_squared_error in 1.6 (TypeError),
    # and sklearn's validation raises on NaN/inf rather than returning NaN.
    mse = np.mean(np.square(y_true - y_pred))
    return float(np.sqrt(mse))
# --- Plotting Functions (Utilities) ---
# Plot helpers now live in forecasting_model.io.plotting and are imported at the top of this module.
# --- Fold Evaluation Function ---
def _inverse_transform_targets(
    y_true_scaled: np.ndarray,
    y_pred_scaled: np.ndarray,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    fold_id: int,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Return (y_true, y_pred) on the original scale as 1-D arrays, or None if the scaler fails."""
    # Every value becomes its own row in a single column; presumably the scaler
    # was fitted on one target column (TODO confirm against the training code).
    y_true_col = y_true_scaled.reshape(-1, 1)
    y_pred_col = y_pred_scaled.reshape(-1, 1)
    if target_scaler is None:
        logger.info("No target scaler provided, assuming inputs are already on original scale.")
        return y_true_col.flatten(), y_pred_col.flatten()
    try:
        logger.debug("Inverse transforming predictions and targets.")
        y_true_inv = target_scaler.inverse_transform(y_true_col)
        y_pred_inv = target_scaler.inverse_transform(y_pred_col)
    except Exception as e:
        logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
        logger.error("Metrics calculation will be skipped due to inverse transform failure.")
        return None
    return np.asarray(y_true_inv).flatten(), np.asarray(y_pred_inv).flatten()


def _compute_fold_metrics(y_true_np: np.ndarray, y_pred_np: np.ndarray, fold_id: int) -> Dict[str, float]:
    """Compute MAE and RMSE with torchmetrics.functional; metrics that fail stay NaN."""
    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}
    try:
        # float64 avoids needlessly down-casting the inverse-transformed values.
        y_true_tensor = torch.as_tensor(y_true_np, dtype=torch.float64)
        y_pred_tensor = torch.as_tensor(y_pred_np, dtype=torch.float64)
        mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
        mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
        metrics['MAE'] = mae_tensor.item()
        metrics['RMSE'] = torch.sqrt(mse_tensor).item()
        logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
    except Exception as e:
        logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
    return metrics


def _save_fold_evaluation_plots(
    y_true_np: np.ndarray,
    y_pred_np: np.ndarray,
    eval_config: EvaluationConfig,
    fold_id: int,
    output_dir: str,
) -> None:
    """Create and save the four evaluation plots; any failure is logged and never raised."""
    logger.info(f"Generating evaluation plots for Fold {fold_id}...")
    fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
    try:
        setup_plot_style()  # Apply consistent styling
    except Exception as e:
        # Style is cosmetic: keep going and try the plots with default styling.
        logger.error(f"Failed to set up plot style for Fold {fold_id}: {e}", exc_info=True)

    title_suffix = f"Fold {fold_id} Test Set"
    residuals_np = y_true_np - y_pred_np
    # Integer x-axis over the flattened (sample-major) values; the flattened
    # series has n_samples * horizon points.
    plot_indices = np.arange(y_true_np.size)
    xlabel = "Time Index (Flattened Horizon x Samples)"

    # Each entry is (file name, figure factory). Factories run inside the
    # per-plot try below, so one failing plot does not prevent the others.
    plot_specs = [
        ("predictions_vs_actual.png", lambda: create_time_series_plot(
            plot_indices, y_true_np, y_pred_np,
            f"Predictions vs Actual - {title_suffix}",
            xlabel=xlabel,
            ylabel="Value (Original Scale)",
            max_points=eval_config.plot_sample_size,
        )),
        ("scatter_predictions.png", lambda: create_scatter_plot(
            y_true_np, y_pred_np,
            f"Scatter Plot - {title_suffix}",
            xlabel="Actual Values (Original Scale)",
            ylabel="Predicted Values (Original Scale)",
        )),
        ("residuals_time.png", lambda: create_residuals_plot(
            plot_indices, residuals_np,
            f"Residuals Over Time - {title_suffix}",
            xlabel=xlabel,
            ylabel="Residual (Original Scale)",
            max_points=eval_config.plot_sample_size,
        )),
        ("residuals_distribution.png", lambda: create_residuals_distribution_plot(
            residuals_np,
            f"Residuals Distribution - {title_suffix}",
            xlabel="Residual Value (Original Scale)",
            ylabel="Density",
        )),
    ]

    n_saved = 0
    for filename, make_figure in plot_specs:
        try:
            save_plot(make_figure(), fold_plot_dir / filename)
            n_saved += 1
        except Exception as e:
            logger.error(f"Failed to generate or save plot '{filename}' for Fold {fold_id}: {e}", exc_info=True)
    if n_saved == len(plot_specs):
        logger.info(f"Evaluation plots saved to: {fold_plot_dir}")
    elif n_saved > 0:
        logger.info(f"{n_saved}/{len(plot_specs)} evaluation plots saved to: {fold_plot_dir}")


def evaluate_fold_predictions(
    y_true_scaled: np.ndarray,
    y_pred_scaled: np.ndarray,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str,  # Base output directory (e.g., output/cv_results)
    time_index: Optional[np.ndarray] = None  # Accepted for compatibility; currently unused
) -> Dict[str, float]:
    """
    Process a fold's test-set predictions: inverse-scale, score, and plot.

    Takes scaled predictions and targets and inverse transforms them with
    ``target_scaler`` (when one is given). It then computes MAE and RMSE on the
    original scale via ``torchmetrics.functional``. When
    ``eval_config.save_plots`` is set, it also writes evaluation plots using the
    ``io.plotting`` utilities. Model inference must already have been run.

    Args:
        y_true_scaled: Scaled ground-truth targets, shape (n_samples, horizon).
        y_pred_scaled: Scaled model predictions, shape (n_samples, horizon).
        target_scaler: Scaler fitted on the target during training, used to map
            values back to the original scale. If None, the inputs are assumed
            to already be on the original scale.
        eval_config: Evaluation settings; ``save_plots`` and
            ``plot_sample_size`` are read here.
        fold_num: Zero-based fold number. It is reported, and used in file
            names, as ``fold_num + 1``.
        output_dir: Base directory; plots are saved under
            ``<output_dir>/fold_<NN>/plots``.
        time_index: Accepted for API compatibility but currently ignored. Plots
            always use integer indices over the flattened
            (n_samples * horizon) values.

    Returns:
        ``{'MAE': float, 'RMSE': float}`` on the original scale. Both values are
        NaN when the input is empty or the inverse transform fails. A metric
        also stays NaN if its computation fails. Plotting failures are logged
        and never affect the returned metrics.

    Raises:
        ValueError: If the two arrays differ in shape or are not 2-D.
    """
    fold_id = fold_num + 1  # Use 1-based indexing for reporting/filenames
    logger.info(f"Processing evaluation results for Fold {fold_id}...")

    if y_true_scaled.shape != y_pred_scaled.shape:
        raise ValueError(f"Shape mismatch between targets and predictions: "
                         f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
    if y_true_scaled.ndim != 2:
        raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")
    n_samples, horizon = y_true_scaled.shape
    logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")
    if time_index is not None:
        logger.debug("time_index is currently not used; plots use integer indices.")

    # Check emptiness up front: a scaler would otherwise raise on 0 samples
    # and produce a misleading "inverse scaling" error traceback.
    if y_true_scaled.size == 0:
        logger.warning(f"Skipping metric calculation and plotting for Fold {fold_id} due to empty data.")
        return {'MAE': np.nan, 'RMSE': np.nan}

    inverted = _inverse_transform_targets(y_true_scaled, y_pred_scaled, target_scaler, fold_id)
    if inverted is None:
        return {'MAE': np.nan, 'RMSE': np.nan}
    y_true_np, y_pred_np = inverted

    metrics = _compute_fold_metrics(y_true_np, y_pred_np, fold_id)

    if eval_config.save_plots:
        _save_fold_evaluation_plots(y_true_np, y_pred_np, eval_config, fold_id, output_dir)

    logger.info(f"Evaluation processing finished for Fold {fold_id}.")
    return metrics
# --- (Optional) Wrapper for use without PyTorch Lightning, or for direct testing ---
# This function runs inference itself and then delegates inverse scaling, metric
# computation and plotting to evaluate_fold_predictions.
def evaluate_model_on_fold_test_set(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str
) -> Dict[str, float]:
    """
    [Optional Function] Run inference on a fold's test set and evaluate it.

    Puts ``model`` in eval mode on ``device``. It then collects scaled
    predictions and targets from every ``(X_batch, y_batch)`` pair yielded by
    ``test_loader``. The results are handed to ``evaluate_fold_predictions``
    for inverse scaling, metrics and plots. Useful for standalone testing or
    when not using ``pl.Trainer.test()``. The model is left in eval mode on
    ``device`` after the call.

    Args:
        model: Model producing scaled predictions from ``X_batch``.
        test_loader: Loader yielding ``(X_batch, y_batch)`` pairs.
        device: Device to run inference on.
        target_scaler: Passed through to ``evaluate_fold_predictions``.
        eval_config: Passed through to ``evaluate_fold_predictions``.
        fold_num: Zero-based fold number.
        output_dir: Base output directory, passed through.

    Returns:
        The metrics dict from ``evaluate_fold_predictions``, or
        ``{'MAE': nan, 'RMSE': nan}`` if the loader produced no batches.

    Raises:
        ValueError: If a batch fails during inference (including an
            output/target shape mismatch that squeezing a trailing size-1
            dimension does not fix), or if the batch results cannot be
            concatenated. The original error is chained as ``__cause__``.
    """
    fold_id = fold_num + 1
    logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_id}...")
    model.eval()
    model.to(device)

    pred_batches: List[torch.Tensor] = []
    target_batches: List[torch.Tensor] = []
    with torch.no_grad():
        for i, (X_batch, y_batch) in enumerate(test_loader):
            try:
                outputs = model(X_batch.to(device))  # Scaled outputs
                # Tolerate one extra trailing size-1 dimension on the outputs.
                if (outputs.shape != y_batch.shape
                        and outputs.ndim == y_batch.ndim + 1
                        and outputs.shape[-1] == 1):
                    outputs = outputs.squeeze(-1)
                if outputs.shape != y_batch.shape:
                    raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
                pred_batches.append(outputs.cpu())
                target_batches.append(y_batch.cpu())  # Keep targets on CPU
            except Exception as e:
                logger.error(f"Error during inference batch {i} for Fold {fold_id}: {e}", exc_info=True)
                raise ValueError(f"Inference failed on batch {i} for Fold {fold_id}") from e

    # Both lists are appended in lockstep, so checking either is equivalent.
    if not pred_batches or not target_batches:
        logger.error(f"No prediction results collected for Fold {fold_id}. Check test_loader.")
        return {'MAE': np.nan, 'RMSE': np.nan}

    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        # NumPy has no bfloat16 dtype (Tensor.numpy() raises on it); upcast
        # half-precision tensors to float32 before converting.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        return tensor.numpy()

    try:
        y_pred_scaled = _to_numpy(torch.cat(pred_batches, dim=0))
        y_true_scaled = _to_numpy(torch.cat(target_batches, dim=0))
    except Exception as e:
        logger.error(f"Error concatenating prediction results for Fold {fold_id}: {e}", exc_info=True)
        raise ValueError("Failed to combine batch results during evaluation inference.") from e

    # No time index is available here, so plots use integer indices.
    return evaluate_fold_predictions(
        y_true_scaled=y_true_scaled,
        y_pred_scaled=y_pred_scaled,
        target_scaler=target_scaler,
        eval_config=eval_config,
        fold_num=fold_num,
        output_dir=output_dir,
        time_index=None,
    )