Files
entrix_case_challange/forecasting_model/io/plotting.py
2025-05-02 14:36:19 +02:00

307 lines
11 KiB
Python

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
import logging
from pathlib import Path
# Module-level logger, named after this module, used by all plotting helpers below.
logger = logging.getLogger(__name__)
def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Configure the global plotting style used by subsequent figures.

    Args:
        use_seaborn: When True, apply the seaborn "whitegrid" theme with the
            "muted" palette and set the default figure size to 12x6. When
            False, switch matplotlib to its 'default' style instead.
    """
    # Guard clause: plain matplotlib styling was requested.
    if not use_seaborn:
        plt.style.use('default')
        logger.debug("Using default matplotlib plot style.")
        return

    # Best effort: a failure to apply the theme is logged, never raised.
    try:
        sns.set_theme(style="whitegrid", palette="muted")
        plt.rcParams['figure.figsize'] = (12, 6)
        logger.debug("Seaborn plot style set.")
    except Exception as e:
        logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save a matplotlib figure to disk, creating the parent directory if needed.

    The figure is always closed afterwards, whether or not saving succeeded.

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save
            the plot to. Can be a string or Path object.

    Raises:
        OSError: If the parent directory cannot be created, or if writing
            the file fails at the OS level.
        Exception: For any other error raised while saving the figure.
    """
    filepath = Path(filename)
    try:
        # Directory creation and file writing are handled separately so that
        # an OSError raised while writing the file (e.g. permission denied)
        # is not misreported as a directory-creation failure.
        try:
            filepath.parent.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
            raise

        try:
            # Tight bounding box and a moderate resolution.
            fig.savefig(filepath, bbox_inches='tight', dpi=150)
        except Exception as e:
            logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
            raise

        logger.info(f"Plot saved successfully to: {filepath}")
    finally:
        # Close the figure to free memory regardless of saving success.
        plt.close(fig)
def create_time_series_plot(
    x: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Upper bound on the number of points drawn. When the input
            is longer, every k-th point is kept, with k chosen so that at most
            max_points remain. None, 0 or a negative value disables
            subsampling.

    Returns:
        The generated matplotlib Figure object. For empty input a new, empty
        figure is returned.

    Raises:
        ValueError: If the inputs are not 1D or do not share the same shape.
    """
    if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating time series plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 6))

    n_points = len(x)
    if max_points is not None and 0 < max_points < n_points:
        # Ceiling division: a floor-based stride (n // max) can leave up to
        # almost twice max_points on the plot, or not subsample at all.
        step = (n_points + max_points - 1) // max_points
        plot_x = x[::step]
        plot_y_true = y_true[::step]
        plot_y_pred = y_pred[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_y_true = y_true
        plot_y_pred = y_pred
        effective_title = title

    ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
    ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)
    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig
def create_scatter_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Actual Values",
    ylabel: str = "Predicted Values"
) -> plt.Figure:
    """
    Create a scatter plot of actual vs predicted values.

    Only pairs where both values are finite are drawn. Both axes share the
    same limits (the data range padded by 5% on each side) and a dashed
    y = x reference line is added.

    Args:
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object. For empty input a new, empty
        figure is returned; if no finite pairs exist, the figure is returned
        with blank axes.

    Raises:
        ValueError: If the inputs are not 1D or do not share the same shape.
    """
    if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
        raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
    if len(y_true) == 0:
        logger.warning("Attempting to create scatter plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating scatter plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 8))  # Square figure common for scatter

    # Drop NaN *and* +/-inf: an infinite value would propagate into the
    # min/max below and yield non-finite axis limits, which matplotlib rejects.
    valid_mask = np.isfinite(y_true) & np.isfinite(y_pred)
    if not np.any(valid_mask):
        logger.warning(f"No valid (finite) data points found for scatter plot '{title}'.")
        return fig

    y_true_valid = y_true[valid_mask]
    y_pred_valid = y_pred[valid_mask]

    min_val = min(y_true_valid.min(), y_pred_valid.min())
    max_val = max(y_true_valid.max(), y_pred_valid.max())
    plot_range = max_val - min_val
    if plot_range < 1e-6:
        # All points are (nearly) identical; avoid a zero-width axis range.
        plot_range = 1.0
    lim_min = min_val - 0.05 * plot_range
    lim_max = max_val + 0.05 * plot_range

    ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
    ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_aspect('equal', adjustable='box')  # Ensure square scaling
    fig.tight_layout()
    return fig
def create_residuals_plot(
    x: np.ndarray,
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Residual (Actual - Predicted)",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a plot of residuals over time with a zero-error reference line.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Upper bound on the number of points drawn. When the input
            is longer, every k-th point is kept, with k chosen so that at most
            max_points remain. None, 0 or a negative value disables
            subsampling.

    Returns:
        The generated matplotlib Figure object. For empty input a new, empty
        figure is returned.

    Raises:
        ValueError: If the inputs are not 1D or do not share the same shape.
    """
    if not (x.shape == residuals.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create residuals time plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals time plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 5))  # Often wider than tall

    n_points = len(x)
    if max_points is not None and 0 < max_points < n_points:
        # Ceiling division: a floor-based stride (n // max) can leave up to
        # almost twice max_points on the plot, or not subsample at all.
        step = (n_points + max_points - 1) // max_points
        plot_x = x[::step]
        plot_residuals = residuals[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_residuals = residuals
        effective_title = title

    ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
    ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)
    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig
def create_residuals_distribution_plot(
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Residual Value",
    ylabel: str = "Density"
) -> plt.Figure:
    """
    Create a distribution plot (histogram and KDE) of residuals using seaborn.

    Non-finite residuals (NaN, +/-inf) are excluded from both the plot and the
    mean / standard deviation shown on it. If seaborn fails, a plain
    density-normalised matplotlib histogram is drawn instead.

    Args:
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object. For empty input a new, empty
        figure is returned; if no finite residuals exist, the figure is
        returned with blank axes.

    Raises:
        ValueError: If input array shape is invalid.
    """
    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")
    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Keep only finite values: +/-inf would make the histogram range and the
    # mean/std below non-finite, just like NaN would.
    residuals_valid = residuals[np.isfinite(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (finite) data points found for residual distribution plot '{title}'.")
        return fig

    # Seaborn histplot combines histogram and KDE in one call.
    try:
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Remove anything seaborn managed to draw before failing, then fall
        # back to a basic histogram. It is density-normalised, so the
        # caller's y-axis label (default "Density") remains accurate.
        ax.clear()
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')
    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig