intermediate backup
@@ -1,82 +1,325 @@
import logging
import os
from pathlib import Path
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from typing import Dict, Any, Optional, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler  # For type hinting target_scaler

# import matplotlib.pyplot as plt  # No longer needed directly
# import seaborn as sns  # No longer needed directly

# Assuming config_model and io.plotting are accessible
from forecasting_model.utils.config_model import EvaluationConfig
from forecasting_model.io.plotting import (  # Import the plotting utilities
    setup_plot_style,
    save_plot,
    create_time_series_plot,
    create_scatter_plot,
    create_residuals_plot,
    create_residuals_distribution_plot
)

logger = logging.getLogger(__name__)


# --- Metric Calculations (Utilities - Optional) ---
# (Keep calculate_mae_np, calculate_rmse_np if needed as standalone utils)

def calculate_mae_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Mean Absolute Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated MAE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for MAE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate MAE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_absolute_error
        mae = mean_absolute_error(y_true, y_pred)
    except ImportError:
        mae = np.mean(np.abs(y_true - y_pred))
    return float(mae)


def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Root Mean Squared Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated RMSE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for RMSE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate RMSE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_squared_error
        mse = mean_squared_error(y_true, y_pred)
    except ImportError:
        mse = np.mean((y_true - y_pred) ** 2)
    rmse = np.sqrt(mse)
    return float(rmse)
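

# A quick usage sketch of the NumPy utilities above (illustrative values only):
#
#   >>> calculate_mae_np(np.array([1.0, 2.0]), np.array([1.5, 1.5]))
#   0.5
#   >>> calculate_rmse_np(np.array([1.0, 2.0]), np.array([1.5, 1.5]))
#   0.5
#
# Both return NaN (and log a message) on shape mismatch or empty inputs.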


# --- Plotting Functions (Utilities) ---
# REMOVED - These are now imported from io.plotting


# --- Fold Evaluation Function ---

def evaluate_fold_predictions(
    y_true_scaled: np.ndarray,
    y_pred_scaled: np.ndarray,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str,  # Base output directory (e.g., output/cv_results)
    time_index: Optional[np.ndarray] = None  # Optional: pass time index for the x-axis
) -> Dict[str, float]:
    """
    Processes prediction results for a fold's test set using torchmetrics.

    Takes scaled predictions and targets, inverse transforms them,
    calculates final metrics (MAE, RMSE) using torchmetrics.functional,
    and generates evaluation plots using utilities from io.plotting.
    Assumes model inference is already done.

    Args:
        y_true_scaled: NumPy array of scaled ground truth targets (n_samples, horizon).
        y_pred_scaled: NumPy array of scaled model predictions (n_samples, horizon).
        target_scaler: The scaler fitted on the target variable during training, needed
            for inverse transforming to the original scale. Can be None.
        eval_config: Configuration object for evaluation parameters (e.g., plotting).
        fold_num: The current fold number (e.g., 0, 1, ...).
        output_dir: The base directory to save fold-specific outputs (plots, metrics).
        time_index: Optional array representing the time index for the test set,
            used for the x-axis in time-based plots. If None, integer indices are used.

    Returns:
        Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
        original scale. Metrics will be NaN if inverse transform or calculation fails.

    Raises:
        ValueError: If input shapes are inconsistent or the required scaler is missing.
    """
logger.info(f"Processing evaluation results for Fold {fold_num + 1}...")
|
||||
fold_id = fold_num + 1 # Use 1-based indexing for reporting/filenames
|
||||
|
||||
if y_true_scaled.shape != y_pred_scaled.shape:
|
||||
raise ValueError(f"Shape mismatch between targets and predictions: "
|
||||
f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
|
||||
if y_true_scaled.ndim != 2:
|
||||
raise ValueError(f"Expected 2D arrays for targets and predictions, got {y_true_scaled.ndim}D")
|
||||
|
||||
n_samples, horizon = y_true_scaled.shape
|
||||
logger.debug(f"Processing {n_samples} samples with horizon {horizon}.")
|
||||
|
||||
# --- Inverse Transform (Outputs NumPy) ---
|
||||
y_true_flat_scaled = y_true_scaled.reshape(-1, 1)
|
||||
y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)
|
||||
|
||||
y_true_inv_np: np.ndarray
|
||||
y_pred_inv_np: np.ndarray
|
||||
|
||||
if target_scaler is not None:
|
||||
try:
|
||||
logger.debug("Inverse transforming predictions and targets.")
|
||||
y_true_inv_np = target_scaler.inverse_transform(y_true_flat_scaled)
|
||||
y_pred_inv_np = target_scaler.inverse_transform(y_pred_flat_scaled)
|
||||
# Flatten NumPy arrays for metric calculation and plotting
|
||||
y_true_np = y_true_inv_np.flatten()
|
||||
y_pred_np = y_pred_inv_np.flatten()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inverse scaling for Fold {fold_id}: {e}", exc_info=True)
|
||||
logger.error("Metrics calculation will be skipped due to inverse transform failure.")
|
||||
return {'MAE': np.nan, 'RMSE': np.nan}
|
||||
else:
|
||||
logger.info("No target scaler provided, assuming inputs are already on original scale.")
|
||||
# Flatten NumPy arrays for metric calculation and plotting
|
||||
y_true_np = y_true_flat_scaled.flatten()
|
||||
y_pred_np = y_pred_flat_scaled.flatten()
|
||||
|
||||
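    # (Note on the reshape above: sklearn scalers are fitted on 2D column vectors,
    #  so inverse_transform expects shape (n, 1); e.g.
    #      StandardScaler().fit(col).inverse_transform(scaled_col)
    #  round-trips values back to the original scale.)
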
    # --- Calculate Metrics using torchmetrics.functional ---
    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}  # Initialize with NaN
    try:
        if len(y_true_np) > 0:  # Check if data exists after potential failures
            y_true_tensor = torch.from_numpy(y_true_np).float().cpu()
            y_pred_tensor = torch.from_numpy(y_pred_np).float().cpu()

            mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
            mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
            rmse_tensor = torch.sqrt(mse_tensor)

            metrics['MAE'] = mae_tensor.item()
            metrics['RMSE'] = rmse_tensor.item()

            logger.info(f"Fold {fold_id} Test Set Metrics (torchmetrics): "
                        f"MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f}")
        else:
            logger.warning(f"Skipping metric calculation for Fold {fold_id} due to empty data after inverse transform.")

    except Exception as e:
        logger.error(f"Failed to calculate metrics using torchmetrics for Fold {fold_id}: {e}", exc_info=True)
        # metrics already initialized to NaN

    # --- Generate Plots (Optional - uses plotting utilities) ---
    if eval_config.save_plots and len(y_true_np) > 0:
        logger.info(f"Generating evaluation plots for Fold {fold_id}...")
        # Define the plot directory and set up styling
        fold_plot_dir = Path(output_dir) / f"fold_{fold_id:02d}" / "plots"
        setup_plot_style()  # Apply consistent styling

        title_suffix = f"Fold {fold_id} Test Set"
        residuals_np = y_true_np - y_pred_np

        # Determine the x-axis: use the provided time_index if available, else integer indices.
        # Note: the flattened y_true/y_pred have length n_samples * horizon, so a matching
        # index for this flattened view would be needed if time_index were used directly.
        # Simple approach: use integer indices for the flattened data.
        plot_indices = np.arange(len(y_true_np))
        xlabel = "Time Index (Flattened Horizon x Samples)"
        # If a time_index corresponding to the start of each forecast is passed,
        # more sophisticated x-axis handling could be done, but integer indices are simpler.
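
        # A hedged sketch of that fancier handling (assuming time_index holds one
        # timestamp per forecast origin, i.e. shape (n_samples,)); left commented
        # out since the repeated-origin axis is only an approximation:
        #
        #     if time_index is not None and len(time_index) == n_samples:
        #         plot_indices = np.repeat(time_index, horizon)  # one entry per flattened point
        #         xlabel = "Time (forecast origin, repeated per horizon step)"
        #
        # Exact per-step timestamps would additionally require the series frequency
        # to offset each horizon step from its origin.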
        try:
            # Create and save each plot using the utility functions
            fig_ts = create_time_series_plot(
                plot_indices, y_true_np, y_pred_np,
                f"Predictions vs Actual - {title_suffix}",
                xlabel=xlabel,
                ylabel="Value (Original Scale)",
                max_points=eval_config.plot_sample_size
            )
            save_plot(fig_ts, fold_plot_dir / "predictions_vs_actual.png")

            fig_scatter = create_scatter_plot(
                y_true_np, y_pred_np,
                f"Scatter Plot - {title_suffix}",
                xlabel="Actual Values (Original Scale)",
                ylabel="Predicted Values (Original Scale)"
            )
            save_plot(fig_scatter, fold_plot_dir / "scatter_predictions.png")

            fig_res_time = create_residuals_plot(
                plot_indices, residuals_np,
                f"Residuals Over Time - {title_suffix}",
                xlabel=xlabel,
                ylabel="Residual (Original Scale)",
                max_points=eval_config.plot_sample_size
            )
            save_plot(fig_res_time, fold_plot_dir / "residuals_time.png")

            fig_res_dist = create_residuals_distribution_plot(
                residuals_np,
                f"Residuals Distribution - {title_suffix}",
                xlabel="Residual Value (Original Scale)",
                ylabel="Density"
            )
            save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution.png")

            logger.info(f"Evaluation plots saved to: {fold_plot_dir}")

        except Exception as e:
            logger.error(f"Failed to generate or save one or more plots for Fold {fold_id}: {e}", exc_info=True)
            # Continue without plots; metrics are already calculated.

    elif eval_config.save_plots and len(y_true_np) == 0:
        logger.warning(f"Skipping plot generation for Fold {fold_id} due to empty data.")

    logger.info(f"Evaluation processing finished for Fold {fold_id}.")
    return metrics
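

# A minimal, illustrative call of evaluate_fold_predictions (all names below are
# hypothetical; shapes are (n_samples, horizon) and the scaler was fitted on the
# training targets):
#
#     scaler = StandardScaler().fit(train_targets.reshape(-1, 1))
#     metrics = evaluate_fold_predictions(
#         y_true_scaled=test_targets_scaled,
#         y_pred_scaled=model_preds_scaled,
#         target_scaler=scaler,
#         eval_config=eval_config,
#         fold_num=0,
#         output_dir="output/cv_results",
#     )
#     # metrics -> {'MAE': ..., 'RMSE': ...} on the original scale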


# --- (Optional) Wrapper for non-PL usage or direct testing ---
# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str
) -> Dict[str, float]:
    """
    [Optional Function] Evaluates a given model on a fold's test set.

    Runs the inference loop, collects scaled results, then processes them using
    `evaluate_fold_predictions` (which now uses the plotting utilities).
    Useful for standalone testing or when not using pl.Trainer.test().
    """
    logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
    model.eval()
    model.to(device)

    all_preds_scaled_list: List[torch.Tensor] = []
    all_targets_scaled_list: List[torch.Tensor] = []

    with torch.no_grad():
        for i, (X_batch, y_batch) in enumerate(test_loader):
            try:
                X_batch = X_batch.to(device)
                outputs = model(X_batch)  # Scaled outputs

                # Ensure outputs match the target shape (e.g., handle a trailing dimension)
                if outputs.shape != y_batch.shape:
                    if outputs.ndim == y_batch.ndim + 1 and outputs.shape[-1] == 1:
                        outputs = outputs.squeeze(-1)
                    if outputs.shape != y_batch.shape:
                        raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")

                all_preds_scaled_list.append(outputs.cpu())
                all_targets_scaled_list.append(y_batch.cpu())  # Keep targets on CPU
            except Exception as e:
                logger.error(f"Error during inference batch {i} for Fold {fold_num + 1}: {e}", exc_info=True)
                raise ValueError(f"Inference failed on batch {i} for Fold {fold_num + 1}") from e

    # Concatenate results from all batches
    try:
        if not all_preds_scaled_list or not all_targets_scaled_list:
            logger.error(f"No prediction results collected for Fold {fold_num + 1}. Check test_loader.")
            return {'MAE': np.nan, 'RMSE': np.nan}

        y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()
        y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
    except Exception as e:
        logger.error(f"Error concatenating prediction results for Fold {fold_num + 1}: {e}", exc_info=True)
        raise ValueError("Failed to combine batch results during evaluation inference.") from e

    # Process the collected predictions using the refactored function.
    # No time_index is passed here by default, so plotting uses integer indices.
    return evaluate_fold_predictions(
        y_true_scaled=y_true_scaled,
        y_pred_scaled=y_pred_scaled,
        target_scaler=target_scaler,
        eval_config=eval_config,
        fold_num=fold_num,
        output_dir=output_dir,
        time_index=None  # Explicitly pass None
    )
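

# A self-contained smoke-test sketch for running this module directly. Everything
# here is hypothetical (dummy model, random data), and the EvaluationConfig
# constructor arguments are assumed, so it is left commented out:
#
#     if __name__ == "__main__":
#         from torch.utils.data import TensorDataset
#         logging.basicConfig(level=logging.INFO)
#
#         class DummyModel(torch.nn.Module):
#             def forward(self, x: torch.Tensor) -> torch.Tensor:
#                 return torch.zeros(x.shape[0], 3)  # predict a 3-step horizon of zeros
#
#         loader = DataLoader(TensorDataset(torch.randn(8, 5, 1), torch.randn(8, 3)), batch_size=4)
#         cfg = EvaluationConfig(save_plots=False)  # assumed field name
#         print(evaluate_model_on_fold_test_set(
#             DummyModel(), loader, torch.device("cpu"), None, cfg, 0, "output/tmp"))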