init
This commit is contained in:
82
forecasting_model/evaluation.py
Normal file
82
forecasting_model/evaluation.py
Normal file
@ -0,0 +1,82 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from typing import Dict, Any, Optional
|
||||
from utils.config_model import EvaluationConfig
|
||||
|
||||
def calculate_mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""
|
||||
Calculate Mean Absolute Error.
|
||||
"""
|
||||
# TODO: Implement MAE calculation
|
||||
pass
|
||||
|
||||
def calculate_rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""
|
||||
Calculate Root Mean Squared Error.
|
||||
"""
|
||||
# TODO: Implement RMSE calculation
|
||||
pass
|
||||
|
||||
def plot_predictions_vs_actual(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
title_suffix: str,
|
||||
filename: str,
|
||||
max_points: Optional[int] = None
|
||||
) -> None:
|
||||
"""
|
||||
Create line plot of predictions vs actual values.
|
||||
"""
|
||||
# TODO: Implement prediction vs actual plot
|
||||
pass
|
||||
|
||||
def plot_scatter_predictions(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
title_suffix: str,
|
||||
filename: str
|
||||
) -> None:
|
||||
"""
|
||||
Create scatter plot of predictions vs actual values.
|
||||
"""
|
||||
# TODO: Implement scatter plot
|
||||
pass
|
||||
|
||||
def plot_residuals_time(
|
||||
residuals: np.ndarray,
|
||||
title_suffix: str,
|
||||
filename: str,
|
||||
max_points: Optional[int] = None
|
||||
) -> None:
|
||||
"""
|
||||
Create plot of residuals over time.
|
||||
"""
|
||||
# TODO: Implement residuals time plot
|
||||
pass
|
||||
|
||||
def plot_residuals_distribution(
|
||||
residuals: np.ndarray,
|
||||
title_suffix: str,
|
||||
filename: str
|
||||
) -> None:
|
||||
"""
|
||||
Create histogram/KDE of residuals.
|
||||
"""
|
||||
# TODO: Implement residuals distribution plot
|
||||
pass
|
||||
|
||||
def evaluate_fold(
|
||||
model: torch.nn.Module,
|
||||
test_loader: DataLoader,
|
||||
loss_fn: torch.nn.Module,
|
||||
device: torch.device,
|
||||
target_scaler: Any,
|
||||
eval_config: EvaluationConfig,
|
||||
fold_num: int
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate model on test set and generate plots.
|
||||
"""
|
||||
# TODO: Implement full evaluation pipeline
|
||||
pass
|
Reference in New Issue
Block a user