intermediate backup
This commit is contained in:
@ -1,5 +1,26 @@
|
||||
"""
|
||||
IO utilities for the forecasting model.
|
||||
Input/Output utilities for the forecasting model package.
|
||||
Currently, primarily includes plotting functions used internally by evaluation.
|
||||
"""
|
||||
|
||||
This package contains utilities for data loading, saving, and visualization.
|
||||
"""
|
||||
# Expose plotting utilities if intended for external use
|
||||
# from .plotting import (
|
||||
# setup_plot_style,
|
||||
# save_plot,
|
||||
# create_time_series_plot,
|
||||
# create_scatter_plot,
|
||||
# create_residuals_plot,
|
||||
# create_residuals_distribution_plot
|
||||
# )
|
||||
|
||||
# __all__ = [
|
||||
# "setup_plot_style",
|
||||
# "save_plot",
|
||||
# "create_time_series_plot",
|
||||
# "create_scatter_plot",
|
||||
# "create_residuals_plot",
|
||||
# "create_residuals_distribution_plot",
|
||||
# ]
|
||||
|
||||
# If nothing is intended for public API from this submodule, leave this file empty
|
||||
# or with just a docstring.
|
@ -1,75 +1,307 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_plot_style() -> None:
|
||||
def setup_plot_style(use_seaborn: bool = True) -> None:
|
||||
"""
|
||||
Set up consistent plotting style.
|
||||
"""
|
||||
# TODO: Implement plot style configuration
|
||||
pass
|
||||
Set up a consistent plotting style using seaborn if enabled.
|
||||
|
||||
def save_plot(fig: plt.Figure, filename: str) -> None:
|
||||
Args:
|
||||
use_seaborn: Whether to apply seaborn styling.
|
||||
"""
|
||||
Save plot to file with proper error handling.
|
||||
if use_seaborn:
|
||||
try:
|
||||
sns.set_theme(style="whitegrid", palette="muted")
|
||||
plt.rcParams['figure.figsize'] = (12, 6) # Default figure size
|
||||
logger.debug("Seaborn plot style set.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
|
||||
else:
|
||||
# Optional: Define a default matplotlib style if seaborn is not used
|
||||
plt.style.use('default')
|
||||
logger.debug("Using default matplotlib plot style.")
|
||||
|
||||
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
|
||||
"""
|
||||
# TODO: Implement plot saving with error handling
|
||||
pass
|
||||
Save matplotlib figure to a file with directory creation and error handling.
|
||||
|
||||
Args:
|
||||
fig: The matplotlib Figure object to save.
|
||||
filename: The full path (including filename and extension) to save the plot to.
|
||||
Can be a string or Path object.
|
||||
|
||||
Raises:
|
||||
OSError: If the directory cannot be created.
|
||||
Exception: For other file saving errors.
|
||||
"""
|
||||
filepath = Path(filename)
|
||||
try:
|
||||
# Create the parent directory if it doesn't exist
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig.savefig(filepath, bbox_inches='tight', dpi=150) # Save with tight bounding box and decent resolution
|
||||
logger.info(f"Plot saved successfully to: {filepath}")
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
|
||||
raise # Re-raise OSError for directory creation issues
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
|
||||
raise # Re-raise other saving errors
|
||||
finally:
|
||||
# Close the figure to free up memory, regardless of saving success
|
||||
plt.close(fig)
|
||||
|
||||
def create_time_series_plot(
|
||||
x: np.ndarray,
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
title: str,
|
||||
xlabel: str,
|
||||
ylabel: str,
|
||||
xlabel: str = "Time Index",
|
||||
ylabel: str = "Value",
|
||||
max_points: Optional[int] = None
|
||||
) -> plt.Figure:
|
||||
"""
|
||||
Create a time series plot with actual vs predicted values.
|
||||
Create a time series plot comparing actual vs predicted values.
|
||||
|
||||
Args:
|
||||
x: The array for the x-axis (e.g., time steps, indices).
|
||||
y_true: Ground truth values (1D array).
|
||||
y_pred: Predicted values (1D array).
|
||||
title: Title for the plot.
|
||||
xlabel: Label for the x-axis.
|
||||
ylabel: Label for the y-axis.
|
||||
max_points: Maximum number of points to display (subsamples if needed).
|
||||
|
||||
Returns:
|
||||
The generated matplotlib Figure object.
|
||||
|
||||
Raises:
|
||||
ValueError: If input array shapes are incompatible.
|
||||
"""
|
||||
# TODO: Implement time series plot creation
|
||||
pass
|
||||
if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
|
||||
raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
|
||||
if len(x) == 0:
|
||||
logger.warning("Attempting to create time series plot with empty data.")
|
||||
# Return an empty figure or raise error? Let's return empty.
|
||||
return plt.figure()
|
||||
|
||||
logger.debug(f"Creating time series plot: {title}")
|
||||
fig, ax = plt.subplots(figsize=(15, 6)) # Consistent size
|
||||
|
||||
n_points = len(x)
|
||||
indices = np.arange(n_points) # Use internal indices for potential slicing
|
||||
|
||||
if max_points and n_points > max_points:
|
||||
step = max(1, n_points // max_points)
|
||||
plot_indices = indices[::step]
|
||||
plot_x = x[::step]
|
||||
plot_y_true = y_true[::step]
|
||||
plot_y_pred = y_pred[::step]
|
||||
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
|
||||
else:
|
||||
plot_x = x
|
||||
plot_y_true = y_true
|
||||
plot_y_pred = y_pred
|
||||
effective_title = title
|
||||
|
||||
ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
|
||||
ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)
|
||||
|
||||
ax.set_title(effective_title, fontsize=14)
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
ax.legend()
|
||||
ax.grid(True, linestyle='--', alpha=0.6)
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
def create_scatter_plot(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
title: str,
|
||||
xlabel: str,
|
||||
ylabel: str
|
||||
xlabel: str = "Actual Values",
|
||||
ylabel: str = "Predicted Values"
|
||||
) -> plt.Figure:
|
||||
"""
|
||||
Create a scatter plot of actual vs predicted values.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values (1D array).
|
||||
y_pred: Predicted values (1D array).
|
||||
title: Title for the plot.
|
||||
xlabel: Label for the x-axis.
|
||||
ylabel: Label for the y-axis.
|
||||
|
||||
Returns:
|
||||
The generated matplotlib Figure object.
|
||||
|
||||
Raises:
|
||||
ValueError: If input array shapes are incompatible.
|
||||
"""
|
||||
# TODO: Implement scatter plot creation
|
||||
pass
|
||||
if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
|
||||
raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
|
||||
if len(y_true) == 0:
|
||||
logger.warning("Attempting to create scatter plot with empty data.")
|
||||
return plt.figure()
|
||||
|
||||
logger.debug(f"Creating scatter plot: {title}")
|
||||
fig, ax = plt.subplots(figsize=(8, 8)) # Square figure common for scatter
|
||||
|
||||
# Determine plot limits, handle potential NaNs
|
||||
valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
|
||||
if not np.any(valid_mask):
|
||||
logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
|
||||
# Return empty figure, plot would be blank
|
||||
return fig
|
||||
|
||||
y_true_valid = y_true[valid_mask]
|
||||
y_pred_valid = y_pred[valid_mask]
|
||||
|
||||
min_val = min(y_true_valid.min(), y_pred_valid.min())
|
||||
max_val = max(y_true_valid.max(), y_pred_valid.max())
|
||||
plot_range = max_val - min_val
|
||||
if plot_range < 1e-6: # Handle cases where all points are identical
|
||||
plot_range = 1.0 # Avoid zero range
|
||||
|
||||
lim_min = min_val - 0.05 * plot_range
|
||||
lim_max = max_val + 0.05 * plot_range
|
||||
|
||||
ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
|
||||
ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)
|
||||
|
||||
ax.set_title(title, fontsize=14)
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
ax.set_xlim(lim_min, lim_max)
|
||||
ax.set_ylim(lim_min, lim_max)
|
||||
ax.legend()
|
||||
ax.grid(True, linestyle='--', alpha=0.6)
|
||||
ax.set_aspect('equal', adjustable='box') # Ensure square scaling
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
def create_residuals_plot(
|
||||
x: np.ndarray,
|
||||
residuals: np.ndarray,
|
||||
title: str,
|
||||
xlabel: str,
|
||||
ylabel: str,
|
||||
xlabel: str = "Time Index",
|
||||
ylabel: str = "Residual (Actual - Predicted)",
|
||||
max_points: Optional[int] = None
|
||||
) -> plt.Figure:
|
||||
"""
|
||||
Create a plot of residuals over time.
|
||||
|
||||
Args:
|
||||
x: The array for the x-axis (e.g., time steps, indices).
|
||||
residuals: Array of residual values (1D array).
|
||||
title: Title for the plot.
|
||||
xlabel: Label for the x-axis.
|
||||
ylabel: Label for the y-axis.
|
||||
max_points: Maximum number of points to display (subsamples if needed).
|
||||
|
||||
Returns:
|
||||
The generated matplotlib Figure object.
|
||||
|
||||
Raises:
|
||||
ValueError: If input array shapes are incompatible.
|
||||
"""
|
||||
# TODO: Implement residuals plot creation
|
||||
pass
|
||||
if not (x.shape == residuals.shape and x.ndim == 1):
|
||||
raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
|
||||
if len(x) == 0:
|
||||
logger.warning("Attempting to create residuals time plot with empty data.")
|
||||
return plt.figure()
|
||||
|
||||
logger.debug(f"Creating residuals time plot: {title}")
|
||||
fig, ax = plt.subplots(figsize=(15, 5)) # Often wider than tall
|
||||
|
||||
n_points = len(x)
|
||||
indices = np.arange(n_points)
|
||||
|
||||
if max_points and n_points > max_points:
|
||||
step = max(1, n_points // max_points)
|
||||
plot_indices = indices[::step]
|
||||
plot_x = x[::step]
|
||||
plot_residuals = residuals[::step]
|
||||
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
|
||||
else:
|
||||
plot_x = x
|
||||
plot_residuals = residuals
|
||||
effective_title = title
|
||||
|
||||
ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
|
||||
ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)
|
||||
|
||||
ax.set_title(effective_title, fontsize=14)
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
ax.legend()
|
||||
ax.grid(True, linestyle='--', alpha=0.6)
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
def create_residuals_distribution_plot(
|
||||
residuals: np.ndarray,
|
||||
title: str,
|
||||
xlabel: str,
|
||||
ylabel: str
|
||||
xlabel: str = "Residual Value",
|
||||
ylabel: str = "Density"
|
||||
) -> plt.Figure:
|
||||
"""
|
||||
Create a distribution plot of residuals.
|
||||
Create a distribution plot (histogram and KDE) of residuals using seaborn.
|
||||
|
||||
Args:
|
||||
residuals: Array of residual values (1D array).
|
||||
title: Title for the plot.
|
||||
xlabel: Label for the x-axis.
|
||||
ylabel: Label for the y-axis.
|
||||
|
||||
Returns:
|
||||
The generated matplotlib Figure object.
|
||||
|
||||
Raises:
|
||||
ValueError: If input array shape is invalid.
|
||||
"""
|
||||
# TODO: Implement residuals distribution plot creation
|
||||
pass
|
||||
if residuals.ndim != 1:
|
||||
raise ValueError("Input array (residuals) must be 1D.")
|
||||
if len(residuals) == 0:
|
||||
logger.warning("Attempting to create residuals distribution plot with empty data.")
|
||||
return plt.figure()
|
||||
|
||||
logger.debug(f"Creating residuals distribution plot: {title}")
|
||||
fig, ax = plt.subplots(figsize=(8, 6))
|
||||
|
||||
# Filter out NaNs before plotting and calculating stats
|
||||
residuals_valid = residuals[~np.isnan(residuals)]
|
||||
if len(residuals_valid) == 0:
|
||||
logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
|
||||
return fig # Return empty figure
|
||||
|
||||
# Use seaborn histplot which combines histogram and KDE
|
||||
try:
|
||||
sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
|
||||
except Exception as e:
|
||||
logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
|
||||
# Fallback to basic matplotlib histogram if seaborn fails
|
||||
ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)
|
||||
ylabel = "Frequency" # Adjust label if only histogram shown
|
||||
|
||||
mean_res = np.mean(residuals_valid)
|
||||
std_res = np.std(residuals_valid)
|
||||
ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')
|
||||
|
||||
ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
|
||||
ax.set_xlabel(xlabel, fontsize=12)
|
||||
ax.set_ylabel(ylabel, fontsize=12)
|
||||
ax.legend()
|
||||
ax.grid(True, axis='y', linestyle='--', alpha=0.6)
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
Reference in New Issue
Block a user