import logging
from pathlib import Path
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Assuming sklearn scalers are available
from sklearn.preprocessing import MinMaxScaler, StandardScaler

logger = logging.getLogger(__name__)


def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Set up a consistent plotting style using seaborn if enabled.

    Args:
        use_seaborn: Whether to apply seaborn styling.
    """
    if use_seaborn:
        try:
            # Use a style that works well for plots with multiple lines
            sns.set_theme(style="whitegrid", palette="viridis")
            plt.rcParams['figure.figsize'] = (15, 7)  # Slightly larger default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        # Fall back to a default matplotlib style if seaborn is not used
        plt.style.use('default')
        plt.rcParams['figure.figsize'] = (15, 7)
        logger.debug("Using default matplotlib plot style.")


def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save a matplotlib figure to a file, creating directories and handling errors.

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save the plot to.
            Can be a string or Path object.

    Note:
        Errors during directory creation or saving are logged rather than re-raised,
        so the finally block can always close the figure and free memory.
    """
    filepath = Path(filename)
    try:
        # Create the parent directory if it doesn't exist
        filepath.parent.mkdir(parents=True, exist_ok=True)
        # Save with a tight bounding box and decent resolution
        fig.savefig(filepath, bbox_inches='tight', dpi=150)
        logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        # Don't re-raise: the finally block must still run to close the figure
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
    except Exception as e:
        # Don't re-raise: the finally block must still run to close the figure
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
    finally:
        # Close the figure to free up memory, regardless of saving success or failure
        try:
            plt.close(fig)
            logger.debug(f"Closed figure for plot {filepath}.")
        except Exception as e:
            logger.warning(f"Failed to close figure for plot {filepath}: {e}")


def create_time_series_plot(
    x: Union[np.ndarray, pd.Index],  # Allow pd.Index for time axis
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs. predicted values.

    NOTE: For multi-horizon forecasts, this typically plots only ONE selected horizon.

    Args:
        x: The array or index for the x-axis (e.g., time steps, datetime index).
            Must align with y_true/y_pred.
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
""" # Add check for pd.Index for x if not isinstance(x, (np.ndarray, pd.Index)) or x.shape[0] != y_true.shape[0] or x.shape[0] != y_pred.shape[0] or y_true.ndim != 1 or y_pred.ndim != 1: raise ValueError(f"Input shapes mismatch or invalid types: x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x.") if len(x) == 0: logger.warning("Attempting to create time series plot with empty data.") # Return an empty figure or raise error? Let's return empty. return plt.figure() logger.debug(f"Creating time series plot: {title}") fig, ax = plt.subplots(figsize=(15, 6)) # Consistent size n_points = len(x) indices = np.arange(n_points) # Use internal indices for potential slicing if max_points and n_points > max_points: step = max(1, n_points // max_points) plot_indices = indices[::step] plot_x = x[::step] plot_y_true = y_true[::step] plot_y_pred = y_pred[::step] effective_title = f'{title} (Sampled {len(plot_indices)} points)' else: plot_x = x plot_y_true = y_true plot_y_pred = y_pred effective_title = title ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5) ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1) ax.set_title(effective_title, fontsize=14) ax.set_xlabel(xlabel, fontsize=12) ax.set_ylabel(ylabel, fontsize=12) ax.legend() ax.grid(True, linestyle='--', alpha=0.6) fig.tight_layout() return fig def create_scatter_plot( y_true: np.ndarray, y_pred: np.ndarray, title: str, xlabel: str = "Actual Values", ylabel: str = "Predicted Values" ) -> plt.Figure: """ Create a scatter plot of actual vs predicted values. Args: y_true: Ground truth values (1D array). y_pred: Predicted values (1D array). title: Title for the plot. xlabel: Label for the x-axis. ylabel: Label for the y-axis. Returns: The generated matplotlib Figure object. Raises: ValueError: If input array shapes are incompatible. 
""" if not (y_true.shape == y_pred.shape and y_true.ndim == 1): raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.") if len(y_true) == 0: logger.warning("Attempting to create scatter plot with empty data.") return plt.figure() logger.debug(f"Creating scatter plot: {title}") fig, ax = plt.subplots(figsize=(8, 8)) # Square figure common for scatter # Determine plot limits, handle potential NaNs valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred) if not np.any(valid_mask): logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.") # Return empty figure, plot would be blank return fig y_true_valid = y_true[valid_mask] y_pred_valid = y_pred[valid_mask] min_val = min(y_true_valid.min(), y_pred_valid.min()) max_val = max(y_true_valid.max(), y_pred_valid.max()) plot_range = max_val - min_val if plot_range < 1e-6: # Handle cases where all points are identical plot_range = 1.0 # Avoid zero range lim_min = min_val - 0.05 * plot_range lim_max = max_val + 0.05 * plot_range ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions') ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5) ax.set_title(title, fontsize=14) ax.set_xlabel(xlabel, fontsize=12) ax.set_ylabel(ylabel, fontsize=12) ax.set_xlim(lim_min, lim_max) ax.set_ylim(lim_min, lim_max) ax.legend() ax.grid(True, linestyle='--', alpha=0.6) ax.set_aspect('equal', adjustable='box') # Ensure square scaling fig.tight_layout() return fig def create_residuals_plot( x: np.ndarray, residuals: np.ndarray, title: str, xlabel: str = "Time Index", ylabel: str = "Residual (Actual - Predicted)", max_points: Optional[int] = None ) -> plt.Figure: """ Create a plot of residuals over time. Args: x: The array for the x-axis (e.g., time steps, indices). residuals: Array of residual values (1D array). title: Title for the plot. xlabel: Label for the x-axis. ylabel: Label for the y-axis. max_points: Maximum number of points to display (subsamples if needed). Returns: The generated matplotlib Figure object. Raises: ValueError: If input array shapes are incompatible. """ if not (x.shape == residuals.shape and x.ndim == 1): raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.") if len(x) == 0: logger.warning("Attempting to create residuals time plot with empty data.") return plt.figure() logger.debug(f"Creating residuals time plot: {title}") fig, ax = plt.subplots(figsize=(15, 5)) # Often wider than tall n_points = len(x) indices = np.arange(n_points) if max_points and n_points > max_points: step = max(1, n_points // max_points) plot_indices = indices[::step] plot_x = x[::step] plot_residuals = residuals[::step] effective_title = f'{title} (Sampled {len(plot_indices)} points)' else: plot_x = x plot_residuals = residuals effective_title = title ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals') ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5) ax.set_title(effective_title, fontsize=14) ax.set_xlabel(xlabel, fontsize=12) ax.set_ylabel(ylabel, fontsize=12) ax.legend() ax.grid(True, linestyle='--', alpha=0.6) fig.tight_layout() return fig def create_residuals_distribution_plot( residuals: np.ndarray, title: str, xlabel: str = "Residual Value", ylabel: str = "Density" ) -> plt.Figure: """ Create a distribution plot (histogram and KDE) of residuals using seaborn. Args: residuals: Array of residual values (1D array). title: Title for the plot. 
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shape is invalid.
    """
    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")

    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Filter out NaNs before plotting and calculating stats
    residuals_valid = residuals[~np.isnan(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
        return fig  # Return empty figure

    # Use seaborn histplot, which combines histogram and KDE
    try:
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.",
                     exc_info=True)
        # Fall back to a basic matplotlib histogram if seaborn fails
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)
        ylabel = "Frequency"  # Adjust label if only the histogram is shown

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')

    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()
    return fig


def create_multi_horizon_time_series_plot(
    y_true_scaled_all_horizons: np.ndarray,  # (N, H)
    y_pred_scaled_all_horizons: np.ndarray,  # (N, H)
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
    prediction_time_index_h1: pd.DatetimeIndex,  # Time index for the first-horizon predictions
    forecast_horizons: List[int],
    title: str,
    xlabel: str = "Time",
    ylabel: str = "Value (Original Scale)",
    max_points: Optional[int] = 1000  # Limit points for clarity
) -> plt.Figure:
    """
    Create a time series plot comparing actual values to predictions for multiple horizons.

    Predictions for each horizon are plotted on their corresponding target time step.

    Args:
        y_true_scaled_all_horizons: Ground truth values (N, H array) on the scaled scale.
        y_pred_scaled_all_horizons: Predicted values (N, H array) on the scaled scale.
        target_scaler: The scaler used for the target variable, needed for inverse transform.
        prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
            Length should be N.
        forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input shapes are incompatible or the horizons list is invalid.
    """
    if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
        raise ValueError(
            f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and "
            f"y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match."
        )
    if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
        raise ValueError(
            f"y arrays must be 2D (N, H) where H is the number of horizons "
            f"({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}."
        )
    if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
        raise ValueError(
            f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match "
            f"the number of predictions ({y_true_scaled_all_horizons.shape[0]})."
        )
    if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
        logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
    if not forecast_horizons:
        raise ValueError("forecast_horizons list cannot be empty.")

    logger.debug(f"Creating multi-horizon time series plot: {title}")
    setup_plot_style()  # Apply standard style
    fig, ax = plt.subplots(figsize=(18, 8))  # Larger figure for multi-horizon plots

    n_points = y_true_scaled_all_horizons.shape[0]
    plot_indices = np.arange(n_points)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = plot_indices[::step]
        # Subsample the data and index
        y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
        y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
        time_index_h1_plot = prediction_time_index_h1[plot_indices]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        y_true_scaled_plot = y_true_scaled_all_horizons
        y_pred_scaled_plot = y_pred_scaled_all_horizons
        time_index_h1_plot = prediction_time_index_h1
        effective_title = title

    # Inverse transform the subsampled data
    y_true_inv_plot = None
    y_pred_inv_plot = None
    if target_scaler is not None:
        try:
            # Scaler expects (N * H, 1), so reshape (N, H) to (N*H, 1)
            y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
            y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
            # Reshape back to (N, H)
            y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
            y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
            logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
        except Exception as e:
            logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
            # Fall back to plotting scaled data if the inverse transform fails
            y_true_inv_plot = y_true_scaled_plot
            y_pred_inv_plot = y_pred_scaled_plot
            ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
    else:
        # No scaler provided: plot the scaled data as-is
        y_true_inv_plot = y_true_scaled_plot
        y_pred_inv_plot = y_pred_scaled_plot
        ylabel = f"{ylabel} (Scaled Data)"

    if y_true_inv_plot is None or y_pred_inv_plot is None:
        # Should not be reachable given the fallbacks above, but kept as a safeguard
        logger.error("Inverse transformed data is None, cannot plot.")
        return fig  # Return empty figure

    # Plot actuals (using h1's time index as the reference point)
    ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.',
            linestyle='-', markersize=4, linewidth=1.5, color='black')  # Actuals for h1

    # Plot predictions for each horizon
    colors = sns.color_palette("viridis", len(forecast_horizons))  # Distinct color per horizon
    linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1)  # Cycle through linestyles

    for i, horizon in enumerate(forecast_horizons):
        preds_h = y_pred_inv_plot[:, i]
        # Calculate the time index for this horizon by shifting the h1 index.
        # Assumes the time index frequency matches the horizon steps.
        try:
            time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h')  # Assuming hourly steps
            ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x',
                    linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
        except Exception as e:
            logger.warning(
                f"Could not calculate time index for horizon {horizon}: {e}. "
                f"Skipping plot for this horizon.",
                exc_info=True
            )

    # Configure plot appearance
    ax.set_title(effective_title, fontsize=16)  # Slightly larger title
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(fontsize=10)  # Smaller legend font
    ax.grid(True, linestyle='--', alpha=0.6)

    # Improve x-axis readability for datetimes
    fig.autofmt_xdate()  # Auto-rotate date labels

    fig.tight_layout()
    return fig


def plot_loss_curve_from_csv(
    metrics_csv_path: Union[str, Path],
    output_path: Union[str, Path],
    title: str = "Training Loss Curve",
    train_loss_col: str = "train_loss",  # Matches the loss name logged in model.py
    val_loss_col: str = "val_loss",  # Common validation loss metric logged by PL
    epoch_col: str = "epoch"
) -> None:
    """
    Read training metrics from a PyTorch Lightning CSVLogger file and plot
    training and validation loss curves over epochs.

    Args:
        metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
        output_path: Path where the plot image will be saved.
        title: Title for the plot.
        train_loss_col: Name of the column containing epoch-level training loss.
        val_loss_col: Name of the column containing epoch-level validation loss.
        epoch_col: Name of the column containing the epoch number.

    Raises:
        FileNotFoundError: If the metrics_csv_path does not exist.
        KeyError: If required columns are not found in the CSV.
        Exception: For other plotting or file reading errors.
    """
    logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
    metrics_path = Path(metrics_csv_path)

    if not metrics_path.is_file():
        raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")

    try:
        metrics_df = pd.read_csv(metrics_path)

        # Check that the required columns exist
        required_cols = [epoch_col, train_loss_col]
        # The validation column might be the scaled loss or the original-scale MAE
        possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale']
        found_val_col = None
        for col in possible_val_cols:
            if col in metrics_df.columns:
                found_val_col = col
                break

        missing_cols = [col for col in required_cols if col not in metrics_df.columns]
        if missing_cols or not found_val_col:
            raise KeyError(
                f"Missing required columns in {metrics_path}: {missing_cols} "
                f"or a suitable validation loss/metric column from {possible_val_cols}."
            )

        # --- Plotting ---
        setup_plot_style()  # Apply standard style
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color1 = 'tab:red'
        ax1.set_xlabel(epoch_col.capitalize())
        # Adjust the ylabel based on the actual column name used for the train loss
        ax1.set_ylabel(train_loss_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color1)

        # Drop NaNs specific to this column for plotting integrity
        train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])

        # Ensure the epoch axis starts from 0 or 1 consistently
        if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If the epoch starts from 1 in the plot data but 0 exists, shift the x-axis for alignment
            ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col],
                     color=color1, label='Train Loss', marker='.', linestyle='-')
            logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
        else:
            ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col],
                     color=color1, label='Train Loss', marker='.', linestyle='-')
        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')

        # Plot the validation loss/metric on a twin axis
        ax2 = ax1.twinx()
        color2 = 'tab:blue'
        # Adjust the ylabel based on the actual column name used for the validation metric
        ax2.set_ylabel(found_val_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color2)

        # Drop NaNs specific to the found validation column
        val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])

        # Ensure the epoch axis starts from 0 or 1 consistently
        if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If the epoch starts from 1 in the plot data but 0 exists, shift the x-axis for alignment
            ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col],
                     color=color2, label='Validation Metric', marker='x', linestyle='--')
            logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
        else:
            ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col],
                     color=color2, label='Validation Metric', marker='x', linestyle='--')
        ax2.tick_params(axis='y', labelcolor=color2)

        # Combine legend entries from both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='upper right')

        plt.title(title, fontsize=14)
        fig.tight_layout()  # Otherwise the right y-label is slightly clipped

        # Save the plot (save_plot also closes the figure)
        save_plot(fig, output_path)

    except pd.errors.EmptyDataError:
        logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
    except KeyError as e:
        logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
        raise  # Re-raise the specific error after logging
    except Exception as e:
        logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
        raise  # Re-raise general errors
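

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of how these helpers might be wired together,
# using synthetic data only. The output directory and file names below are
# hypothetical placeholders and are not referenced elsewhere in the project.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    setup_plot_style()

    # Synthetic "actual vs. predicted" series, purely for demonstration
    rng = np.random.default_rng(42)
    n = 500
    time_index = pd.date_range("2024-01-01", periods=n, freq="h")
    y_true_demo = np.sin(np.linspace(0, 20, n)) + rng.normal(0, 0.1, n)
    y_pred_demo = y_true_demo + rng.normal(0, 0.2, n)

    # Time series comparison, subsampled for readability
    fig_ts = create_time_series_plot(
        time_index, y_true_demo, y_pred_demo,
        title="Synthetic Forecast Example", max_points=300,
    )
    save_plot(fig_ts, Path("plots_demo") / "time_series_example.png")  # hypothetical path

    # Actual vs. predicted scatter
    fig_scatter = create_scatter_plot(y_true_demo, y_pred_demo, title="Synthetic Actual vs. Predicted")
    save_plot(fig_scatter, Path("plots_demo") / "scatter_example.png")  # hypothetical path

    # Residual distribution (histogram + KDE)
    residuals_demo = y_true_demo - y_pred_demo
    fig_res = create_residuals_distribution_plot(residuals_demo, title="Synthetic Residual Distribution")
    save_plot(fig_res, Path("plots_demo") / "residuals_dist_example.png")  # hypothetical path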