307 lines
11 KiB
Python
307 lines
11 KiB
Python
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import numpy as np
|
|
from typing import Optional, Union
|
|
import logging
|
|
|
|
from pathlib import Path
|
|
|
|
# Module-level logger named after this module (``__name__``); used by every helper below.
logger = logging.getLogger(__name__)
|
|
|
|
def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Set up a consistent plotting style, using seaborn if enabled.

    With ``use_seaborn=True`` the seaborn "whitegrid" theme with the "muted"
    palette is applied and the default figure size is set to 12x6 inches.
    If that fails, a warning is logged and matplotlib's 'default' style is
    applied instead, so the fallback announced in the log message is what is
    actually in effect. With ``use_seaborn=False`` matplotlib's 'default'
    style is applied directly.

    Args:
        use_seaborn: Whether to apply seaborn styling.
    """
    if not use_seaborn:
        plt.style.use('default')
        logger.debug("Using default matplotlib plot style.")
        return

    try:
        sns.set_theme(style="whitegrid", palette="muted")
        plt.rcParams['figure.figsize'] = (12, 6)  # Default figure size
    except Exception as e:
        # Best effort: styling problems must never abort the caller.
        logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
        # Actually apply the fallback the message promises; without this a
        # partially applied seaborn theme could remain in rcParams.
        plt.style.use('default')
    else:
        logger.debug("Seaborn plot style set.")
|
|
|
|
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save matplotlib figure to a file with directory creation and error handling.

    The parent directory is created if needed, the figure is written with a
    tight bounding box at 150 dpi, and the figure is always closed afterwards
    (whether or not saving succeeded).

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save the plot to.
            Can be a string or Path object.

    Raises:
        OSError: If the directory cannot be created (or the file cannot be written).
        Exception: For other file saving errors.
    """
    filepath = Path(filename)
    try:
        # Directory creation and saving are handled separately so that an
        # OSError raised while *writing* the file (permissions, disk full, ...)
        # is not misreported as a directory-creation failure.
        try:
            # Create the parent directory if it doesn't exist
            filepath.parent.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
            raise  # Re-raise OSError for directory creation issues

        try:
            # Save with tight bounding box and decent resolution
            fig.savefig(filepath, bbox_inches='tight', dpi=150)
        except Exception as e:
            logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
            raise  # Re-raise saving errors unchanged

        logger.info(f"Plot saved successfully to: {filepath}")
    finally:
        # Close the figure to free up memory, regardless of saving success
        plt.close(fig)
|
|
|
|
def create_time_series_plot(
    x: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display. When the input is
            longer, every ``step``-th point is kept so that at most
            ``max_points`` points are drawn. ``None`` or a non-positive value
            disables subsampling.

    Returns:
        The generated matplotlib Figure object. For empty input an empty
        figure (no axes) is returned and a warning is logged.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")

    n_points = len(x)
    if n_points == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure rather than raising.
        return plt.figure()

    logger.debug(f"Creating time series plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 6))  # Consistent size

    if max_points is not None and 0 < max_points < n_points:
        # Ceiling division: guarantees len(x[::step]) <= max_points. Floor
        # division would let up to ~2x max_points through (e.g. 150 points
        # with max_points=100 gives step 1, i.e. no sampling at all).
        step = -(-n_points // max_points)
        plot_x = x[::step]
        plot_y_true = y_true[::step]
        plot_y_pred = y_pred[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_y_true = y_true
        plot_y_pred = y_pred
        effective_title = title

    ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
    ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig
|
|
|
|
def create_scatter_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Actual Values",
    ylabel: str = "Predicted Values"
) -> plt.Figure:
    """
    Create a scatter plot of actual vs predicted values.

    Points where either value is NaN or infinite are dropped before plotting.
    Both axes share the same limits (data range padded by 5%) and an
    ideal ``y = x`` reference line is drawn.

    Args:
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object. For empty input an empty
        figure is returned; if no point survives filtering, a figure with
        blank axes is returned. A warning is logged in both cases.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
        raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
    if len(y_true) == 0:
        logger.warning("Attempting to create scatter plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating scatter plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 8))  # Square figure common for scatter

    # Keep only finite pairs. Filtering NaN alone is not enough: an inf value
    # would propagate into the axis limits, which matplotlib rejects.
    valid_mask = np.isfinite(y_true) & np.isfinite(y_pred)
    if not np.any(valid_mask):
        logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
        # Return the blank figure; there is nothing to draw.
        return fig

    y_true_valid = y_true[valid_mask]
    y_pred_valid = y_pred[valid_mask]

    # Shared limits for both axes so the y = x line is a true diagonal.
    min_val = min(y_true_valid.min(), y_pred_valid.min())
    max_val = max(y_true_valid.max(), y_pred_valid.max())
    plot_range = max_val - min_val
    if plot_range < 1e-6:  # Handle cases where all points are (nearly) identical
        plot_range = 1.0  # Avoid zero range

    lim_min = min_val - 0.05 * plot_range
    lim_max = max_val + 0.05 * plot_range

    ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
    ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_aspect('equal', adjustable='box')  # Ensure square scaling
    fig.tight_layout()

    return fig
|
|
|
|
def create_residuals_plot(
    x: np.ndarray,
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Residual (Actual - Predicted)",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a plot of residuals over time.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display. When the input is
            longer, every ``step``-th point is kept so that at most
            ``max_points`` points are drawn. ``None`` or a non-positive value
            disables subsampling.

    Returns:
        The generated matplotlib Figure object. For empty input an empty
        figure (no axes) is returned and a warning is logged.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == residuals.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")

    n_points = len(x)
    if n_points == 0:
        logger.warning("Attempting to create residuals time plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals time plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 5))  # Often wider than tall

    if max_points is not None and 0 < max_points < n_points:
        # Ceiling division: guarantees len(x[::step]) <= max_points. Floor
        # division would let up to ~2x max_points through (e.g. 150 points
        # with max_points=100 gives step 1, i.e. no sampling at all).
        step = -(-n_points // max_points)
        plot_x = x[::step]
        plot_residuals = residuals[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_residuals = residuals
        effective_title = title

    ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
    # Reference line at zero error.
    ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig
|
|
|
|
def create_residuals_distribution_plot(
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Residual Value",
    ylabel: str = "Density"
) -> plt.Figure:
    """
    Create a distribution plot (histogram and KDE) of residuals using seaborn.

    NaN and infinite residuals are dropped first. The histogram is density
    normalised with 50 bins; if seaborn fails, a plain matplotlib density
    histogram is drawn instead. A vertical line marks the mean, and the
    standard deviation is appended to the title.

    Args:
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object. For empty input an empty
        figure is returned; if no value survives filtering, a figure with
        blank axes is returned. A warning is logged in both cases.

    Raises:
        ValueError: If input array shape is invalid.
    """
    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")
    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Filter out non-finite values before plotting and calculating stats.
    # inf must go as well as NaN: it makes the histogram range non-finite and
    # turns the mean/std into inf/NaN.
    residuals_valid = residuals[np.isfinite(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
        return fig  # Return the blank figure

    # Use seaborn histplot which combines histogram and KDE
    try:
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Discard anything seaborn drew before failing so the fallback is
        # not painted on top of a partial plot.
        ax.clear()
        # The fallback is still density-normalised, so the caller's ylabel
        # (default "Density") remains correct and is left untouched.
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')

    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig