entrix_case_challange/forecasting_model/io/plotting.py
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union, List
import logging
import pandas as pd
from pathlib import Path
# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler
logger = logging.getLogger(__name__)
def setup_plot_style(use_seaborn: bool = True) -> None:
"""
Set up a consistent plotting style using seaborn if enabled.
Args:
use_seaborn: Whether to apply seaborn styling.
"""
if use_seaborn:
try:
# Use a different style that might be better for multiple lines
sns.set_theme(style="whitegrid", palette="viridis") # Changed palette
plt.rcParams['figure.figsize'] = (15, 7) # Slightly larger default figure size
logger.debug("Seaborn plot style set.")
except Exception as e:
logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
else:
# Optional: Define a default matplotlib style if seaborn is not used
plt.style.use('default')
plt.rcParams['figure.figsize'] = (15, 7)
logger.debug("Using default matplotlib plot style.")
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
"""
Save matplotlib figure to a file with directory creation and error handling.
Args:
fig: The matplotlib Figure object to save.
filename: The full path (including filename and extension) to save the plot to.
Can be a string or Path object.
    Notes:
        Errors during directory creation or saving are logged rather than raised;
        the figure is always closed afterwards to free memory.
    """
filepath = Path(filename)
try:
# Create the parent directory if it doesn't exist
filepath.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(filepath, bbox_inches='tight', dpi=150) # Save with tight bounding box and decent resolution
logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
        # Deliberately not re-raised: a failed plot should not abort the caller.
    except Exception as e:
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
        # Deliberately not re-raised; the finally block below still closes the figure.
finally:
# Close the figure to free up memory, regardless of saving success or failure
try:
plt.close(fig)
logger.debug(f"Closed figure for plot {filepath}.")
except Exception as e:
logger.warning(f"Failed to close figure for plot {filepath}: {e}")
def create_time_series_plot(
x: Union[np.ndarray, pd.Index], # Allow pd.Index for time axis
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
xlabel: str = "Time Index",
ylabel: str = "Value",
max_points: Optional[int] = None
) -> plt.Figure:
"""
Create a time series plot comparing actual vs predicted values.
NOTE: When using multi-horizon forecasts, this typically plots only ONE selected horizon.
Args:
x: The array or index for the x-axis (e.g., time steps, datetime index). Should align with y_true/y_pred.
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
    # x may be a numpy array or a pandas Index (e.g. a DatetimeIndex).
    if (not isinstance(x, (np.ndarray, pd.Index))
            or x.shape[0] != y_true.shape[0]
            or x.shape[0] != y_pred.shape[0]
            or y_true.ndim != 1
            or y_pred.ndim != 1):
        raise ValueError(
            f"Input shapes mismatch or invalid types: "
            f"x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), "
            f"y_true({y_true.shape}), y_pred({y_pred.shape}). "
            f"Expecting 1D y arrays and an x of matching length."
        )
if len(x) == 0:
logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure rather than raising.
return plt.figure()
logger.debug(f"Creating time series plot: {title}")
fig, ax = plt.subplots(figsize=(15, 6)) # Consistent size
n_points = len(x)
indices = np.arange(n_points) # Use internal indices for potential slicing
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = indices[::step]
plot_x = x[::step]
plot_y_true = y_true[::step]
plot_y_pred = y_pred[::step]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
plot_x = x
plot_y_true = y_true
plot_y_pred = y_pred
effective_title = title
ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)
ax.set_title(effective_title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
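
# --- Usage sketch (illustrative only; the arrays, index, and output path are hypothetical) ---
#   idx = pd.date_range("2024-01-01", periods=200, freq="h")
#   y_true = np.random.randn(200).cumsum()
#   y_pred = y_true + np.random.randn(200) * 0.5
#   fig = create_time_series_plot(idx, y_true, y_pred,
#                                 title="Test Set: Actual vs Predicted", max_points=100)
#   save_plot(fig, "outputs/plots/test_timeseries.png")
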
def create_scatter_plot(
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
xlabel: str = "Actual Values",
ylabel: str = "Predicted Values"
) -> plt.Figure:
"""
Create a scatter plot of actual vs predicted values.
Args:
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
if len(y_true) == 0:
logger.warning("Attempting to create scatter plot with empty data.")
return plt.figure()
logger.debug(f"Creating scatter plot: {title}")
fig, ax = plt.subplots(figsize=(8, 8)) # Square figure common for scatter
# Determine plot limits, handle potential NaNs
valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
if not np.any(valid_mask):
logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
# Return empty figure, plot would be blank
return fig
y_true_valid = y_true[valid_mask]
y_pred_valid = y_pred[valid_mask]
min_val = min(y_true_valid.min(), y_pred_valid.min())
max_val = max(y_true_valid.max(), y_pred_valid.max())
plot_range = max_val - min_val
if plot_range < 1e-6: # Handle cases where all points are identical
plot_range = 1.0 # Avoid zero range
lim_min = min_val - 0.05 * plot_range
lim_max = max_val + 0.05 * plot_range
ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)
ax.set_title(title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.set_xlim(lim_min, lim_max)
ax.set_ylim(lim_min, lim_max)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
ax.set_aspect('equal', adjustable='box') # Ensure square scaling
fig.tight_layout()
return fig
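
# --- Usage sketch (illustrative only; the arrays and output path are hypothetical) ---
#   y_true = np.random.rand(500)
#   y_pred = y_true + np.random.randn(500) * 0.05
#   fig = create_scatter_plot(y_true, y_pred, title="Validation: Actual vs Predicted",
#                             xlabel="Actual Price", ylabel="Predicted Price")
#   save_plot(fig, "outputs/plots/val_scatter.png")
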
def create_residuals_plot(
x: np.ndarray,
residuals: np.ndarray,
title: str,
xlabel: str = "Time Index",
ylabel: str = "Residual (Actual - Predicted)",
max_points: Optional[int] = None
) -> plt.Figure:
"""
Create a plot of residuals over time.
Args:
x: The array for the x-axis (e.g., time steps, indices).
residuals: Array of residual values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (x.shape == residuals.shape and x.ndim == 1):
raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
if len(x) == 0:
logger.warning("Attempting to create residuals time plot with empty data.")
return plt.figure()
logger.debug(f"Creating residuals time plot: {title}")
fig, ax = plt.subplots(figsize=(15, 5)) # Often wider than tall
n_points = len(x)
indices = np.arange(n_points)
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = indices[::step]
plot_x = x[::step]
plot_residuals = residuals[::step]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
plot_x = x
plot_residuals = residuals
effective_title = title
ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)
ax.set_title(effective_title, fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
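
# --- Usage sketch (illustrative only; assumes y_true and y_pred are 1D arrays of equal length) ---
#   residuals = y_true - y_pred
#   fig = create_residuals_plot(np.arange(len(residuals)), residuals,
#                               title="Test Set Residuals Over Time", max_points=1000)
#   save_plot(fig, "outputs/plots/test_residuals_time.png")
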
def create_residuals_distribution_plot(
residuals: np.ndarray,
title: str,
xlabel: str = "Residual Value",
ylabel: str = "Density"
) -> plt.Figure:
"""
Create a distribution plot (histogram and KDE) of residuals using seaborn.
Args:
residuals: Array of residual values (1D array).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input array shape is invalid.
"""
if residuals.ndim != 1:
raise ValueError("Input array (residuals) must be 1D.")
if len(residuals) == 0:
logger.warning("Attempting to create residuals distribution plot with empty data.")
return plt.figure()
logger.debug(f"Creating residuals distribution plot: {title}")
fig, ax = plt.subplots(figsize=(8, 6))
# Filter out NaNs before plotting and calculating stats
residuals_valid = residuals[~np.isnan(residuals)]
if len(residuals_valid) == 0:
logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
return fig # Return empty figure
# Use seaborn histplot which combines histogram and KDE
try:
sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
except Exception as e:
logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Fallback to a plain matplotlib histogram if seaborn fails; density=True keeps
        # the y-axis on the same "Density" scale, so the label does not need to change.
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)
mean_res = np.mean(residuals_valid)
std_res = np.std(residuals_valid)
ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')
ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
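
# --- Usage sketch (illustrative only; assumes y_true and y_pred are 1D arrays of equal length) ---
#   fig = create_residuals_distribution_plot(y_true - y_pred,
#                                            title="Test Set Residual Distribution")
#   save_plot(fig, "outputs/plots/test_residuals_hist.png")
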
def create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons: np.ndarray, # (N, H)
y_pred_scaled_all_horizons: np.ndarray, # (N, H)
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
prediction_time_index_h1: pd.DatetimeIndex, # Time index for the first horizon predictions
forecast_horizons: List[int],
title: str,
xlabel: str = "Time",
ylabel: str = "Value (Original Scale)",
max_points: Optional[int] = 1000 # Limit points for clarity
) -> plt.Figure:
"""
Create a time series plot comparing actual values to predictions for multiple horizons.
Predictions for each horizon are plotted on their corresponding target time step.
Args:
y_true_scaled_all_horizons: Ground truth values (N, H array) on scaled scale.
y_pred_scaled_all_horizons: Predicted values (N, H array) on scaled scale.
target_scaler: The scaler used for the target variable, needed for inverse transform.
prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
Length should be N.
forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input shapes are incompatible or horizons list is invalid.
"""
if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
raise ValueError(f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match.")
if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
raise ValueError(f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}.")
if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
raise ValueError(f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match the number of predictions ({y_true_scaled_all_horizons.shape[0]}).")
if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
    if not forecast_horizons:
raise ValueError("forecast_horizons list cannot be empty.")
logger.debug(f"Creating multi-horizon time series plot: {title}")
setup_plot_style() # Apply standard style
fig, ax = plt.subplots(figsize=(18, 8)) # Larger figure for multi-horizon
n_points = y_true_scaled_all_horizons.shape[0]
plot_indices = np.arange(n_points)
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = plot_indices[::step]
# Subsample the data and index
y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
time_index_h1_plot = prediction_time_index_h1[plot_indices]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
y_true_scaled_plot = y_true_scaled_all_horizons
y_pred_scaled_plot = y_pred_scaled_all_horizons
time_index_h1_plot = prediction_time_index_h1
effective_title = title
    # Inverse transform the (possibly subsampled) data; fall back to plotting the scaled
    # values if no scaler is provided or the inverse transform fails.
    if target_scaler is not None:
        try:
            # Scaler expects (N * H, 1), so reshape (N, H) accordingly
            y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
            y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
            # Reshape back to (N, H)
            y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
            y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
            logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
        except Exception as e:
            logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
            y_true_inv_plot = y_true_scaled_plot
            y_pred_inv_plot = y_pred_scaled_plot
            ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
    else:
        # No scaler provided: plot the scaled values directly rather than returning an empty figure.
        y_true_inv_plot = y_true_scaled_plot
        y_pred_inv_plot = y_pred_scaled_plot
        ylabel = f"{ylabel} (Scaled Data - No Scaler Provided)"
# Plot Actuals (using h1's time index, as it's the reference point)
ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-', markersize=4, linewidth=1.5, color='black') # Actuals for H1
# Plot predictions for each horizon
colors = sns.color_palette("viridis", len(forecast_horizons)) # Use palette for distinct colors
linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1) # Cycle through linestyles
for i, horizon in enumerate(forecast_horizons):
preds_h = y_pred_inv_plot[:, i]
# Calculate time index for this specific horizon by shifting the h1 index
# Assumes the time index frequency is appropriate for the horizon steps
try:
time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h') # Assuming 'h' for hours
ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x', linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
except Exception as e:
logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. Skipping plot for this horizon.", exc_info=True)
# Configure plot appearance
ax.set_title(effective_title, fontsize=16) # Slightly larger title
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend(fontsize=10) # Smaller legend font
ax.grid(True, linestyle='--', alpha=0.6)
# Improve x-axis readability for datetimes
fig.autofmt_xdate() # Auto-rotate date labels
fig.tight_layout()
return fig
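
# --- Usage sketch (illustrative only; shapes, horizons, scaler, and paths are hypothetical) ---
#   horizons = [1, 6, 12]
#   n = 100
#   y_true_scaled = np.random.randn(n, len(horizons))
#   y_pred_scaled = y_true_scaled + 0.1 * np.random.randn(n, len(horizons))
#   scaler = StandardScaler().fit(np.random.randn(500, 1))  # stand-in for the real target scaler
#   time_index_h1 = pd.date_range("2024-01-01", periods=n, freq="h")
#   fig = create_multi_horizon_time_series_plot(
#       y_true_scaled, y_pred_scaled, scaler, time_index_h1, horizons,
#       title="Validation Forecasts (multi-horizon)",
#   )
#   save_plot(fig, "outputs/plots/val_multi_horizon.png")
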
def plot_loss_curve_from_csv(
metrics_csv_path: Union[str, Path],
output_path: Union[str, Path],
title: str = "Training Loss Curve",
train_loss_col: str = "train_loss", # Changed to match logging in model.py
val_loss_col: str = "val_loss", # Common validation loss metric logged by PL
epoch_col: str = "epoch"
) -> None:
"""
Reads training metrics from a PyTorch Lightning CSVLogger file and plots
training and validation loss curves over epochs.
Args:
metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
output_path: Path where the plot image will be saved.
title: Title for the plot.
train_loss_col: Name of the column containing epoch-level training loss.
val_loss_col: Name of the column containing epoch-level validation loss.
epoch_col: Name of the column containing the epoch number.
Raises:
FileNotFoundError: If the metrics_csv_path does not exist.
KeyError: If required columns are not found in the CSV.
Exception: For other plotting or file reading errors.
"""
logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
metrics_path = Path(metrics_csv_path)
if not metrics_path.is_file():
raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")
try:
metrics_df = pd.read_csv(metrics_path)
# Check if required columns exist
required_cols = [epoch_col, train_loss_col]
# Val loss column might be the scaled loss or the original scale MAE
possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale'] # Include potential names
found_val_col = None
for col in possible_val_cols:
if col in metrics_df.columns:
found_val_col = col
break
        missing_cols = [col for col in required_cols if col not in metrics_df.columns]
        if missing_cols or not found_val_col:
            raise KeyError(
                f"Missing required columns in {metrics_path}: {missing_cols}, "
                f"or no suitable validation loss/metric column found among {possible_val_cols}."
            )
# --- Plotting ---
setup_plot_style() # Apply standard style
fig, ax1 = plt.subplots(figsize=(12, 6))
color1 = 'tab:red'
ax1.set_xlabel(epoch_col.capitalize())
# Adjust ylabel based on actual column name used for train loss
ax1.set_ylabel(train_loss_col.replace('_epoch','').replace('_',' ').capitalize(), color=color1)
# Drop NaNs specific to this column for plotting integrity
train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])
# Filter for epoch column only if needed (usually not for loss plots)
# train_plot_data = train_plot_data[train_plot_data[epoch_col].notna()]
# Ensure epoch starts from 0 or 1 consistently
if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
else:
ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')
# Validation loss/metric plotting on twin axis
ax2 = ax1.twinx()
color2 = 'tab:blue'
# Adjust ylabel based on actual column name used for val metric
ax2.set_ylabel(found_val_col.replace('_epoch','').replace('_',' ').capitalize(), color=color2)
# Drop NaNs specific to the found validation column
val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])
# val_plot_data = val_plot_data[val_plot_data[epoch_col].notna()] # Ensure epoch is not NaN
# Ensure epoch starts from 0 or 1 consistently
if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
else:
ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
ax2.tick_params(axis='y', labelcolor=color2)
# Add legend manually combining lines from both axes
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='upper right')
plt.title(title, fontsize=14)
fig.tight_layout() # Otherwise the right y-label is slightly clipped
# Save the plot
save_plot(fig, output_path)
except pd.errors.EmptyDataError:
logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
except KeyError as e:
logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
raise # Re-raise specific error after logging
except Exception as e:
logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
raise # Re-raise general errors
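
# --- Usage sketch (illustrative only; the CSVLogger directory layout below is hypothetical) ---
#   plot_loss_curve_from_csv(
#       metrics_csv_path="output/logs/csv/version_0/metrics.csv",
#       output_path="output/plots/loss_curve.png",
#       title="Fold 1 Training Loss",
#   )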