"""Matplotlib/seaborn helpers for plotting actual-vs-predicted values and residuals."""

import logging
from pathlib import Path
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

logger = logging.getLogger(__name__)


def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Set up a consistent plotting style using seaborn if enabled.

    Args:
        use_seaborn: Whether to apply seaborn styling. When False, the
            matplotlib 'default' style is applied instead.
    """
    if use_seaborn:
        # Styling is best-effort: a failure here must not abort plotting,
        # so it is logged and matplotlib's current style is left in place.
        try:
            sns.set_theme(style="whitegrid", palette="muted")
            plt.rcParams['figure.figsize'] = (12, 6)  # Default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        plt.style.use('default')
        logger.debug("Using default matplotlib plot style.")


def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save a matplotlib figure to a file, creating the parent directory if needed.

    The figure is always closed afterwards, whether or not saving succeeded.

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save
            the plot to. Can be a string or Path object.

    Raises:
        OSError: If the parent directory cannot be created.
        Exception: Any error raised while writing the file (re-raised as-is).
    """
    filepath = Path(filename)
    try:
        # Directory creation and file writing are handled separately so that
        # each failure is logged with the message that actually describes it
        # (savefig can raise OSError too, e.g. on a permission error).
        try:
            filepath.parent.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
            raise

        try:
            # Tight bounding box and a decent resolution.
            fig.savefig(filepath, bbox_inches='tight', dpi=150)
        except Exception as e:
            logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
            raise

        logger.info(f"Plot saved successfully to: {filepath}")
    finally:
        # Close the figure to free up memory, regardless of saving success.
        plt.close(fig)


def _subsample_step(n_points: int, max_points: Optional[int]) -> int:
    """Return the slice stride that keeps at most ``max_points`` of ``n_points``.

    Returns 1 (no subsampling) when ``max_points`` is None, non-positive, or
    not smaller than ``n_points``.
    """
    if max_points is None or max_points <= 0 or n_points <= max_points:
        return 1
    # Ceiling division: a floor stride would let up to ~2x max_points through.
    return -(-n_points // max_points)


def create_time_series_plot(
    x: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display. When the data is
            longer, every k-th point is kept so that at most ``max_points``
            are drawn. None or a non-positive value disables subsampling.

    Returns:
        The generated matplotlib Figure object (an empty figure if the
        inputs are empty).

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    # Accept any array-like (lists, etc.), not only ndarrays.
    x = np.asarray(x)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating time series plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 6))

    step = _subsample_step(len(x), max_points)
    if step > 1:
        plot_x = x[::step]
        plot_y_true = y_true[::step]
        plot_y_pred = y_pred[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_y_true = y_true
        plot_y_pred = y_pred
        effective_title = title

    ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-',
            markersize=4, linewidth=1.5)
    ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--',
            markersize=4, alpha=0.8, linewidth=1)
    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig


def create_scatter_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Actual Values",
    ylabel: str = "Predicted Values"
) -> plt.Figure:
    """
    Create a scatter plot of actual vs predicted values.

    Pairs in which either value is NaN or infinite are dropped before
    plotting. A dashed y=x reference line is drawn, and both axes share the
    same limits (data range padded by 5%).

    Args:
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object (blank if there is no
        plottable data).

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    # Accept any array-like (lists, etc.), not only ndarrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
        raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
    if len(y_true) == 0:
        logger.warning("Attempting to create scatter plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating scatter plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 8))  # Square figure is common for scatter

    # Drop NaN and +/-inf: infinite values would produce non-finite axis
    # limits, which matplotlib rejects.
    valid_mask = np.isfinite(y_true) & np.isfinite(y_pred)
    if not np.any(valid_mask):
        logger.warning(f"No valid (finite) data points found for scatter plot '{title}'.")
        return fig  # Blank figure
    n_dropped = int(valid_mask.size - np.count_nonzero(valid_mask))
    if n_dropped:
        logger.debug(f"Dropped {n_dropped} non-finite point(s) from scatter plot '{title}'.")

    y_true_valid = y_true[valid_mask]
    y_pred_valid = y_pred[valid_mask]

    min_val = min(y_true_valid.min(), y_pred_valid.min())
    max_val = max(y_true_valid.max(), y_pred_valid.max())
    plot_range = max_val - min_val
    if plot_range < 1e-6:
        # All points (nearly) identical: use a unit range to avoid zero-width axes.
        plot_range = 1.0
    lim_min = min_val - 0.05 * plot_range
    lim_max = max_val + 0.05 * plot_range

    ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
    ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_aspect('equal', adjustable='box')  # Ensure square scaling
    fig.tight_layout()
    return fig


def create_residuals_plot(
    x: np.ndarray,
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Residual (Actual - Predicted)",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a plot of residuals over time with a zero-error reference line.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display. When the data is
            longer, every k-th point is kept so that at most ``max_points``
            are drawn. None or a non-positive value disables subsampling.

    Returns:
        The generated matplotlib Figure object (an empty figure if the
        inputs are empty).

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    # Accept any array-like (lists, etc.), not only ndarrays.
    x = np.asarray(x)
    residuals = np.asarray(residuals)

    if not (x.shape == residuals.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create residuals time plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals time plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 5))  # Wider than tall

    step = _subsample_step(len(x), max_points)
    if step > 1:
        plot_x = x[::step]
        plot_residuals = residuals[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_residuals = residuals
        effective_title = title

    ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4,
            linewidth=1, label='Residuals')
    ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)
    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig


def create_residuals_distribution_plot(
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Residual Value",
    ylabel: str = "Density"
) -> plt.Figure:
    """
    Create a distribution plot (histogram and KDE) of residuals using seaborn.

    NaN and infinite residuals are dropped first. If seaborn fails, a plain
    matplotlib density histogram is drawn instead. The mean is marked with a
    vertical line and the standard deviation is shown in the title.

    Args:
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object (blank if there is no
        plottable data).

    Raises:
        ValueError: If input array shape is invalid.
    """
    # Accept any array-like (lists, etc.), not only ndarrays.
    residuals = np.asarray(residuals)

    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")
    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Drop NaN and +/-inf: histogram binning and the mean/std statistics
    # are undefined for non-finite values.
    residuals_valid = residuals[np.isfinite(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (finite) data points found for residual distribution plot '{title}'.")
        return fig  # Blank figure

    try:
        # seaborn histplot combines the histogram with a KDE curve.
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Remove anything seaborn may have drawn before failing, then draw a
        # plain histogram. density=True keeps the y-axis a density, so the
        # caller's ylabel remains correct and is left untouched.
        ax.clear()
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')
    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig