import logging
from pathlib import Path
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler

logger = logging.getLogger(__name__)


def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    Set up a consistent plotting style using seaborn if enabled.

    Args:
        use_seaborn: Whether to apply seaborn styling.
    """
    if use_seaborn:
        try:
            # whitegrid with the viridis palette works well for plots with multiple lines
            sns.set_theme(style="whitegrid", palette="viridis")
            plt.rcParams['figure.figsize'] = (15, 7)  # Slightly larger default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        # Fall back to the default matplotlib style if seaborn is not used
        plt.style.use('default')
        plt.rcParams['figure.figsize'] = (15, 7)
        logger.debug("Using default matplotlib plot style.")


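# Usage sketch (illustrative only): call once near the start of a plotting or
# reporting script, before any figures are created, so all figures share the style.
#
#     setup_plot_style()                   # seaborn whitegrid + viridis, 15x7 figures
#     setup_plot_style(use_seaborn=False)  # plain matplotlib defaults with the same size

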
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
    """
    Save a matplotlib figure to a file, creating parent directories as needed.

    Errors during directory creation or saving are logged rather than re-raised,
    and the figure is always closed afterwards to free memory.

    Args:
        fig: The matplotlib Figure object to save.
        filename: The full path (including filename and extension) to save the plot to.
            Can be a string or Path object.
    """
    filepath = Path(filename)
    try:
        # Create the parent directory if it doesn't exist
        filepath.parent.mkdir(parents=True, exist_ok=True)

        fig.savefig(filepath, bbox_inches='tight', dpi=150)  # Tight bounding box, decent resolution
        logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
    finally:
        # Close the figure to free up memory, regardless of saving success or failure
        try:
            plt.close(fig)
            logger.debug(f"Closed figure for plot {filepath}.")
        except Exception as e:
            logger.warning(f"Failed to close figure for plot {filepath}: {e}")


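# Usage sketch (illustrative only): assumes `y_true` and `y_pred` are 1D numpy arrays
# produced by an evaluation step elsewhere; the target directory need not exist yet.
#
#     fig = create_scatter_plot(y_true, y_pred, title="Actual vs Predicted")
#     save_plot(fig, Path("reports") / "figures" / "scatter.png")
#
# Note that save_plot closes the figure, so call it last if you still need to modify the plot.

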
def create_time_series_plot(
    x: Union[np.ndarray, pd.Index],  # Allow pd.Index for the time axis
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Value",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs. predicted values.

    NOTE: For multi-horizon forecasts, this typically plots only ONE selected horizon.

    Args:
        x: The array or index for the x-axis (e.g., time steps, datetime index).
            Must align with y_true/y_pred.
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if (
        not isinstance(x, (np.ndarray, pd.Index))
        or y_true.ndim != 1
        or y_pred.ndim != 1
        or x.shape[0] != y_true.shape[0]
        or x.shape[0] != y_pred.shape[0]
    ):
        raise ValueError(
            f"Input shapes mismatch or invalid types: x({type(x)}, "
            f"{x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), "
            f"y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x."
        )
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure rather than raising
        return plt.figure()

    logger.debug(f"Creating time series plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 6))  # Consistent size

    n_points = len(x)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_x = x[::step]
        plot_y_true = y_true[::step]
        plot_y_pred = y_pred[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_y_true = y_true
        plot_y_pred = y_pred
        effective_title = title

    ax.plot(plot_x, plot_y_true, label='Actual', marker='.', linestyle='-', markersize=4, linewidth=1.5)
    ax.plot(plot_x, plot_y_pred, label='Predicted', marker='x', linestyle='--', markersize=4, alpha=0.8, linewidth=1)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig


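# Usage sketch (illustrative only): assumes `test_index` is a pd.DatetimeIndex aligned
# with 1D arrays `y_true` and `y_pred` from the evaluation step.
#
#     fig = create_time_series_plot(
#         test_index, y_true, y_pred,
#         title="Test Set: Actual vs Predicted",
#         max_points=2000,   # subsample long series for readability
#     )
#     save_plot(fig, "reports/figures/test_timeseries.png")

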
def create_scatter_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
    xlabel: str = "Actual Values",
    ylabel: str = "Predicted Values"
) -> plt.Figure:
    """
    Create a scatter plot of actual vs. predicted values.

    Args:
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (y_true.shape == y_pred.shape and y_true.ndim == 1):
        raise ValueError("Input arrays (y_true, y_pred) must be 1D and have the same shape.")
    if len(y_true) == 0:
        logger.warning("Attempting to create scatter plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating scatter plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 8))  # Square figure is standard for scatter plots

    # Determine plot limits, ignoring NaNs
    valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if not np.any(valid_mask):
        logger.warning(f"No valid (non-NaN) data points found for scatter plot '{title}'.")
        # Return the empty figure; the plot would be blank anyway
        return fig

    y_true_valid = y_true[valid_mask]
    y_pred_valid = y_pred[valid_mask]

    min_val = min(y_true_valid.min(), y_pred_valid.min())
    max_val = max(y_true_valid.max(), y_pred_valid.max())
    plot_range = max_val - min_val
    if plot_range < 1e-6:  # Handle cases where all points are (nearly) identical
        plot_range = 1.0  # Avoid a zero-width range

    lim_min = min_val - 0.05 * plot_range
    lim_max = max_val + 0.05 * plot_range

    ax.scatter(y_true_valid, y_pred_valid, alpha=0.5, s=10, label='Predictions')
    ax.plot([lim_min, lim_max], [lim_min, lim_max], 'r--', label='Ideal (y=x)', linewidth=1.5)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_aspect('equal', adjustable='box')  # Ensure square scaling
    fig.tight_layout()

    return fig


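# Usage sketch (illustrative only): points hugging the red y=x line indicate a good fit;
# assumes `y_true` and `y_pred` are 1D arrays on the same (original) scale.
#
#     fig = create_scatter_plot(y_true, y_pred, title="Predicted vs Actual (Test)")
#     save_plot(fig, "reports/figures/test_scatter.png")

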
def create_residuals_plot(
    x: np.ndarray,
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Time Index",
    ylabel: str = "Residual (Actual - Predicted)",
    max_points: Optional[int] = None
) -> plt.Figure:
    """
    Create a plot of residuals over time.

    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == residuals.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, residuals) must be 1D and have the same shape.")
    if len(x) == 0:
        logger.warning("Attempting to create residuals time plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals time plot: {title}")
    fig, ax = plt.subplots(figsize=(15, 5))  # Often wider than tall

    n_points = len(x)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_x = x[::step]
        plot_residuals = residuals[::step]
        effective_title = f'{title} (Sampled {len(plot_x)} points)'
    else:
        plot_x = x
        plot_residuals = residuals
        effective_title = title

    ax.plot(plot_x, plot_residuals, marker='.', linestyle='-', markersize=4, linewidth=1, label='Residuals')
    ax.axhline(0, color='red', linestyle='--', label='Zero Error', linewidth=1.5)

    ax.set_title(effective_title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig


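# Usage sketch (illustrative only): residuals are conventionally actual minus predicted,
# matching the default y-axis label; `x` can be a simple index or a time axis.
#
#     residuals = y_true - y_pred
#     fig = create_residuals_plot(np.arange(len(residuals)), residuals,
#                                 title="Test Residuals Over Time", max_points=2000)
#     save_plot(fig, "reports/figures/test_residuals.png")

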
def create_residuals_distribution_plot(
    residuals: np.ndarray,
    title: str,
    xlabel: str = "Residual Value",
    ylabel: str = "Density"
) -> plt.Figure:
    """
    Create a distribution plot (histogram and KDE) of residuals using seaborn.

    Args:
        residuals: Array of residual values (1D array).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If the input array shape is invalid.
    """
    if residuals.ndim != 1:
        raise ValueError("Input array (residuals) must be 1D.")
    if len(residuals) == 0:
        logger.warning("Attempting to create residuals distribution plot with empty data.")
        return plt.figure()

    logger.debug(f"Creating residuals distribution plot: {title}")
    fig, ax = plt.subplots(figsize=(8, 6))

    # Filter out NaNs before plotting and calculating stats
    residuals_valid = residuals[~np.isnan(residuals)]
    if len(residuals_valid) == 0:
        logger.warning(f"No valid (non-NaN) data points found for residual distribution plot '{title}'.")
        return fig  # Return empty figure

    # Use seaborn histplot, which combines histogram and KDE
    try:
        sns.histplot(residuals_valid, kde=True, bins=50, stat="density", ax=ax)
    except Exception as e:
        logger.error(f"Seaborn histplot failed for '{title}': {e}. Falling back to matplotlib hist.", exc_info=True)
        # Fall back to a plain matplotlib histogram (no KDE); density=True keeps the
        # y-axis a density, so the default ylabel still applies.
        ax.hist(residuals_valid, bins=50, density=True, alpha=0.7)

    mean_res = np.mean(residuals_valid)
    std_res = np.std(residuals_valid)
    ax.axvline(float(mean_res), color='red', linestyle='--', label=f'Mean: {mean_res:.3f}')

    ax.set_title(f'{title}\n(Std Dev: {std_res:.3f})', fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig


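# Usage sketch (illustrative only): a roughly zero-centred, symmetric distribution
# suggests unbiased predictions; heavy tails or skew point to systematic errors.
#
#     fig = create_residuals_distribution_plot(y_true - y_pred, title="Test Residual Distribution")
#     save_plot(fig, "reports/figures/test_residual_hist.png")

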
def create_multi_horizon_time_series_plot(
    y_true_scaled_all_horizons: np.ndarray,  # (N, H)
    y_pred_scaled_all_horizons: np.ndarray,  # (N, H)
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
    prediction_time_index_h1: pd.DatetimeIndex,  # Time index for the first-horizon predictions
    forecast_horizons: List[int],
    title: str,
    xlabel: str = "Time",
    ylabel: str = "Value (Original Scale)",
    max_points: Optional[int] = 1000  # Limit points for clarity
) -> plt.Figure:
    """
    Create a time series plot comparing actual values to predictions for multiple horizons.

    Predictions for each horizon are plotted at their corresponding target time step.

    Args:
        y_true_scaled_all_horizons: Ground truth values (N, H array) on the scaled scale.
        y_pred_scaled_all_horizons: Predicted values (N, H array) on the scaled scale.
        target_scaler: The scaler used for the target variable, needed for inverse transform.
        prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
            Length should be N.
        forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input shapes are incompatible or the horizons list is invalid.
    """
    if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
        raise ValueError(
            f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and "
            f"y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match."
        )
    if not forecast_horizons:
        raise ValueError("forecast_horizons list cannot be empty.")
    if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
        raise ValueError(
            f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). "
            f"Shape is {y_true_scaled_all_horizons.shape}."
        )
    if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
        raise ValueError(
            f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match "
            f"the number of predictions ({y_true_scaled_all_horizons.shape[0]})."
        )
    if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
        logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")

    logger.debug(f"Creating multi-horizon time series plot: {title}")
    setup_plot_style()  # Apply standard style

    fig, ax = plt.subplots(figsize=(18, 8))  # Larger figure for multi-horizon

    n_points = y_true_scaled_all_horizons.shape[0]
    plot_indices = np.arange(n_points)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = plot_indices[::step]
        # Subsample the data and index
        y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
        y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
        time_index_h1_plot = prediction_time_index_h1[plot_indices]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        y_true_scaled_plot = y_true_scaled_all_horizons
        y_pred_scaled_plot = y_pred_scaled_all_horizons
        time_index_h1_plot = prediction_time_index_h1
        effective_title = title

    # Inverse transform the (possibly subsampled) data
    y_true_inv_plot = None
    y_pred_inv_plot = None
    if target_scaler is not None:
        try:
            # The scaler expects a single column, so reshape (N, H) to (N*H, 1)
            y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
            y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
            # Reshape back to (N, H)
            y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
            y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
            logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
        except Exception as e:
            logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
            # Fall back to plotting scaled data if the inverse transform fails
            y_true_inv_plot = y_true_scaled_plot
            y_pred_inv_plot = y_pred_scaled_plot
            ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"

    if y_true_inv_plot is None or y_pred_inv_plot is None:
        # No scaler was provided: plot the scaled values directly instead of bailing out
        logger.warning("No target_scaler provided; plotting scaled values for multi-horizon plot.")
        y_true_inv_plot = y_true_scaled_plot
        y_pred_inv_plot = y_pred_scaled_plot
        ylabel = f"{ylabel} (Scaled)"

    # Plot actuals (using h1's time index, as it is the reference point)
    ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-',
            markersize=4, linewidth=1.5, color='black')  # Actuals for h1

    # Plot predictions for each horizon
    colors = sns.color_palette("viridis", len(forecast_horizons))  # Distinct color per horizon
    linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1)  # Cycle through linestyles

    for i, horizon in enumerate(forecast_horizons):
        preds_h = y_pred_inv_plot[:, i]
        # Calculate the time index for this horizon by shifting the h1 index.
        # Assumes the index frequency matches the horizon step (hours here).
        try:
            time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h')
            ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x',
                    linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
        except Exception as e:
            logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. "
                           f"Skipping plot for this horizon.", exc_info=True)

    # Configure plot appearance
    ax.set_title(effective_title, fontsize=16)  # Slightly larger title
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(fontsize=10)  # Smaller legend font
    ax.grid(True, linestyle='--', alpha=0.6)

    # Improve x-axis readability for datetimes
    fig.autofmt_xdate()  # Auto-rotate date labels

    fig.tight_layout()

    return fig


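# Usage sketch (illustrative only): assumes `targets_scaled` and `preds_scaled` are
# (N, H) arrays in scaled space, `scaler` is the fitted target StandardScaler, and
# `test_index_h1` is the DatetimeIndex of the first-horizon targets (hourly data).
#
#     fig = create_multi_horizon_time_series_plot(
#         y_true_scaled_all_horizons=targets_scaled,
#         y_pred_scaled_all_horizons=preds_scaled,
#         target_scaler=scaler,
#         prediction_time_index_h1=test_index_h1,
#         forecast_horizons=[1, 6, 12, 24],
#         title="Multi-Horizon Forecasts vs Actuals",
#     )
#     save_plot(fig, "reports/figures/multi_horizon_forecast.png")

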
def plot_loss_curve_from_csv(
    metrics_csv_path: Union[str, Path],
    output_path: Union[str, Path],
    title: str = "Training Loss Curve",
    train_loss_col: str = "train_loss",  # Matches the metric name logged in model.py
    val_loss_col: str = "val_loss",      # Common validation loss metric logged by PL
    epoch_col: str = "epoch"
) -> None:
    """
    Read training metrics from a PyTorch Lightning CSVLogger file and plot
    training and validation loss curves over epochs.

    Args:
        metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
        output_path: Path where the plot image will be saved.
        title: Title for the plot.
        train_loss_col: Name of the column containing epoch-level training loss.
        val_loss_col: Name of the column containing epoch-level validation loss.
        epoch_col: Name of the column containing the epoch number.

    Raises:
        FileNotFoundError: If the metrics_csv_path does not exist.
        KeyError: If required columns are not found in the CSV.
        Exception: For other plotting or file reading errors.
    """
    logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
    metrics_path = Path(metrics_csv_path)
    if not metrics_path.is_file():
        raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")

    try:
        metrics_df = pd.read_csv(metrics_path)

        # Check that the required columns exist
        missing_cols = [col for col in (epoch_col, train_loss_col) if col not in metrics_df.columns]
        if missing_cols:
            raise KeyError(f"Missing required columns in {metrics_path}: {missing_cols}.")

        # The validation column may be the scaled loss or an original-scale metric
        possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale']
        found_val_col = next((col for col in possible_val_cols if col in metrics_df.columns), None)
        if found_val_col is None:
            raise KeyError(
                f"No suitable validation loss/metric column found in {metrics_path}; "
                f"expected one of {possible_val_cols}."
            )

        # --- Plotting ---
        setup_plot_style()  # Apply standard style
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color1 = 'tab:red'
        ax1.set_xlabel(epoch_col.capitalize())
        # Derive the y-label from the actual train loss column name
        ax1.set_ylabel(train_loss_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color1)
        # Drop NaNs specific to this column (PL logs train and val rows separately)
        train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])

        # Keep epoch numbering consistent whether logging starts at 0 or 1
        if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # Epoch 0 exists in the CSV but not in the plot data; shift the x-axis for alignment
            ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col],
                     color=color1, label='Train Loss', marker='.', linestyle='-')
            logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
        else:
            ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col],
                     color=color1, label='Train Loss', marker='.', linestyle='-')

        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')

        # Validation loss/metric plotted on a twin axis
        ax2 = ax1.twinx()
        color2 = 'tab:blue'
        # Derive the y-label from the validation column that was actually found
        ax2.set_ylabel(found_val_col.replace('_epoch', '').replace('_', ' ').capitalize(), color=color2)
        # Drop NaNs specific to the found validation column
        val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])

        # Keep epoch numbering consistent whether logging starts at 0 or 1
        if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col],
                     color=color2, label='Validation Metric', marker='x', linestyle='--')
            logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
        else:
            ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col],
                     color=color2, label='Validation Metric', marker='x', linestyle='--')

        ax2.tick_params(axis='y', labelcolor=color2)

        # Build the legend manually, combining lines from both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='upper right')

        plt.title(title, fontsize=14)
        fig.tight_layout()  # Otherwise the right y-label is slightly clipped

        # Save the plot (this also closes the figure)
        save_plot(fig, output_path)

    except pd.errors.EmptyDataError:
        logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
    except KeyError as e:
        logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
        raise  # Re-raise specific error after logging
    except Exception as e:
        logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
        raise  # Re-raise general errors
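

# Usage sketch (illustrative only): `metrics.csv` is the file written by PyTorch
# Lightning's CSVLogger, e.g. lightning_logs/version_0/metrics.csv.
#
#     plot_loss_curve_from_csv("lightning_logs/version_0/metrics.csv",
#                              "reports/figures/loss_curve.png")


if __name__ == "__main__":
    # Minimal smoke test with synthetic data (an illustrative sketch meant to be run
    # manually; output paths here are arbitrary and not part of any pipeline).
    logging.basicConfig(level=logging.INFO)
    setup_plot_style()

    rng = np.random.default_rng(42)
    n = 500
    time_index = pd.date_range("2024-01-01", periods=n, freq="h")
    y_true = np.sin(np.linspace(0, 20, n)) + rng.normal(0, 0.1, n)
    y_pred = y_true + rng.normal(0, 0.2, n)

    out_dir = Path("plots_smoke_test")
    save_plot(create_time_series_plot(time_index, y_true, y_pred, "Synthetic Series"),
              out_dir / "timeseries.png")
    save_plot(create_scatter_plot(y_true, y_pred, "Synthetic Scatter"),
              out_dir / "scatter.png")
    save_plot(create_residuals_plot(np.arange(n), y_true - y_pred, "Synthetic Residuals"),
              out_dir / "residuals.png")
    save_plot(create_residuals_distribution_plot(y_true - y_pred, "Synthetic Residual Distribution"),
              out_dir / "residual_hist.png")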