68 lines
3.1 KiB
Python
68 lines
3.1 KiB
Python
from typing import List, Optional, Dict, Any
|
|
|
|
import numpy as np
|
|
import logging
|
|
|
|
|
|
# Module-level logger, named after this module; interpolate_forecast reports
# validation errors, extrapolation warnings and failures through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
# --- Interpolation Helper ---
|
|
def interpolate_forecast(
    native_horizons: List[int],
    native_predictions: np.ndarray,
    target_horizon: int,
    last_known_actual: Optional[float] = None,
) -> Optional[np.ndarray]:
    """
    Linearly interpolate model predictions at native horizons to a full hourly sequence.

    Args:
        native_horizons: Horizons the model predicts (e.g., [1, 6, 12, 24]), as a list or
            1-D array in any order. Must not be empty and must not contain duplicates.
        native_predictions: Predictions corresponding to native_horizons. Must not be empty.
            Singleton axes are squeezed, so (n,), (n, 1) and (1, n) arrays are accepted;
            values are converted to float.
        target_horizon: The desired length of the hourly forecast (e.g., 24). Values
            below 1 produce an empty array.
        last_known_actual: Optional last actual value before the forecast starts (t=0).
            Inserted as an anchor point at t=0 when every native horizon is > 0.

    Returns:
        A float array of shape (target_horizon,) with the interpolated values for hours
        1..target_horizon, or None if the inputs are invalid or interpolation fails.
        Hours outside the (anchored) native range take the nearest endpoint value,
        which is how np.interp extrapolates.
    """
    # Test None/length explicitly: `not native_horizons` raises ValueError for a
    # multi-element ndarray, which would escape this function instead of returning None.
    if native_horizons is None or len(native_horizons) == 0 or native_predictions is None:
        logger.error("Cannot interpolate with empty native horizons or predictions.")
        return None

    try:
        # dtype=float: np.insert (below) casts the anchor to the array's dtype, so integer
        # predictions would otherwise silently truncate last_known_actual.
        # squeeze + atleast_1d: accept (n, 1) / (1, n) model outputs without collapsing a
        # single prediction into a 0-d scalar.
        fp = np.atleast_1d(np.squeeze(np.asarray(native_predictions, dtype=float)))
        if fp.size == 0:
            logger.error("Cannot interpolate with empty native horizons or predictions.")
            return None
        if fp.ndim != 1:
            logger.error(
                "native_predictions must be one-dimensional, got shape %s",
                np.shape(native_predictions),
            )
            return None
        if fp.shape[0] != len(native_horizons):
            logger.error(
                "Mismatched lengths: native_horizons (%d) vs native_predictions (%d)",
                len(native_horizons),
                fp.shape[0],
            )
            return None

        # Sort horizons (and their predictions with them); np.interp needs increasing xp.
        horizons = np.asarray(native_horizons, dtype=float)
        order = np.argsort(horizons)
        xp = horizons[order]
        fp = fp[order]

        # np.interp requires strictly increasing xp; duplicate horizons make the result
        # depend on sort order, so reject them.
        if np.any(np.diff(xp) == 0):
            logger.error("Duplicate native horizons are not supported: %s", xp.tolist())
            return None

        # Target points for interpolation (hours 1 to target_horizon)
        x_target = np.arange(1, target_horizon + 1, dtype=float)

        # Anchor the start of the curve at t=0 with the last observed value when the model
        # does not already predict t=0. (If horizons start below 0, the anchor is ignored.)
        if last_known_actual is not None and xp[0] > 0:
            xp = np.insert(xp, 0, 0.0)
            fp = np.insert(fp, 0, last_known_actual)
        elif last_known_actual is not None and xp[0] == 0:
            logger.debug(
                "Native horizons include 0, using model's prediction for t=0 instead of last_known_actual."
            )
        elif last_known_actual is None and xp[0] > 1:
            # Hours before the first native horizon get fp[0] held constant by np.interp.
            # When the first horizon is 1 nothing is extrapolated, so no warning is needed.
            logger.warning(
                "No last_known_actual provided and native horizons start > 1. "
                "Interpolation might be less accurate at the beginning."
            )

        # Check if target range requires extrapolation beyond the model's capability
        if target_horizon > xp[-1]:
            logger.warning(
                "Target horizon (%s) extends beyond the maximum native forecast horizon (%s). "
                "Extrapolation will occur (constant value).",
                target_horizon,
                xp[-1],
            )

        return np.interp(x_target, xp, fp)

    except Exception as e:
        # Best-effort helper boundary: callers receive None rather than an exception.
        logger.exception("Linear interpolation failed: %s", e)
        return None
|