intermediate backup
@@ -1,11 +1,15 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
from typing import Optional, Union, List
import logging
import pandas as pd

from pathlib import Path

# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler

logger = logging.getLogger(__name__)

def setup_plot_style(use_seaborn: bool = True) -> None:
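
Since this hunk shows only logger = logging.getLogger(__name__) and no handler setup, a caller that wants to see the module's log messages would typically configure logging itself; a minimal sketch:

    import logging
    logging.basicConfig(level=logging.DEBUG)  # makes the logger.debug/info/warning calls in this module visible
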
@@ -17,14 +21,16 @@ def setup_plot_style(use_seaborn: bool = True) -> None:
    """
    if use_seaborn:
        try:
            sns.set_theme(style="whitegrid", palette="muted")
            plt.rcParams['figure.figsize'] = (12, 6) # Default figure size
            # Use a different style that might be better for multiple lines
            sns.set_theme(style="whitegrid", palette="viridis") # Changed palette
            plt.rcParams['figure.figsize'] = (15, 7) # Slightly larger default figure size
            logger.debug("Seaborn plot style set.")
        except Exception as e:
            logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
    else:
        # Optional: Define a default matplotlib style if seaborn is not used
        plt.style.use('default')
        plt.rcParams['figure.figsize'] = (15, 7)
        logger.debug("Using default matplotlib plot style.")

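A minimal usage sketch of setup_plot_style as changed above (assuming the module's functions are in scope):

    import matplotlib.pyplot as plt

    setup_plot_style()                   # seaborn 'whitegrid' theme, viridis palette, 15x7 default figure size
    setup_plot_style(use_seaborn=False)  # plain matplotlib 'default' style with the same figure size

    fig, ax = plt.subplots()
    ax.plot(range(10), [v ** 0.5 for v in range(10)])
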
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
@@ -49,16 +55,21 @@ def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
        logger.info(f"Plot saved successfully to: {filepath}")
    except OSError as e:
        logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
        raise # Re-raise OSError for directory creation issues
        # Don't re-raise immediately, try closing figure first
        # raise # Re-raise OSError for directory creation issues - Removed to ensure finally runs
    except Exception as e:
        logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
        raise # Re-raise other saving errors
        # Don't re-raise immediately, try closing figure first
    finally:
        # Close the figure to free up memory, regardless of saving success
        plt.close(fig)
        # Close the figure to free up memory, regardless of saving success or failure
        try:
            plt.close(fig)
            logger.debug(f"Closed figure for plot {filepath}.")
        except Exception as e:
            logger.warning(f"Failed to close figure for plot {filepath}: {e}")
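
A minimal sketch of calling save_plot (the output path is illustrative); per the handling above, failures are logged and the figure is always closed in the finally block:

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1, 2, 3], [1.0, 0.7, 0.5, 0.4], label="loss")
    ax.legend()
    save_plot(fig, "plots/loss_curve.png")
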

def create_time_series_plot(
    x: np.ndarray,
    x: Union[np.ndarray, pd.Index], # Allow pd.Index for time axis
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str,
@@ -68,9 +79,9 @@ def create_time_series_plot(
) -> plt.Figure:
    """
    Create a time series plot comparing actual vs predicted values.

    NOTE: When using multi-horizon forecasts, this typically plots only ONE selected horizon.
    Args:
        x: The array for the x-axis (e.g., time steps, indices).
        x: The array or index for the x-axis (e.g., time steps, datetime index). Should align with y_true/y_pred.
        y_true: Ground truth values (1D array).
        y_pred: Predicted values (1D array).
        title: Title for the plot.
@@ -84,8 +95,9 @@ def create_time_series_plot(
    Raises:
        ValueError: If input array shapes are incompatible.
    """
    if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
        raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
    # Add check for pd.Index for x
    if not isinstance(x, (np.ndarray, pd.Index)) or x.shape[0] != y_true.shape[0] or x.shape[0] != y_pred.shape[0] or y_true.ndim != 1 or y_pred.ndim != 1:
        raise ValueError(f"Input shapes mismatch or invalid types: x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x.")
    if len(x) == 0:
        logger.warning("Attempting to create time series plot with empty data.")
        # Return an empty figure or raise error? Let's return empty.
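
With the relaxed check above, x may be a NumPy array or a pandas index; a sketch of a call using a DatetimeIndex (only parameters visible in this diff are passed, data is synthetic):

    import numpy as np
    import pandas as pd

    idx = pd.date_range("2024-01-01", periods=48, freq="h")
    y_true = np.random.rand(48)
    y_pred = y_true + np.random.normal(0.0, 0.05, size=48)
    fig = create_time_series_plot(x=idx, y_true=y_true, y_pred=y_pred, title="Actual vs Predicted")
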
@@ -304,4 +316,243 @@ def create_residuals_distribution_plot(
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

    return fig
    return fig

def create_multi_horizon_time_series_plot(
    y_true_scaled_all_horizons: np.ndarray, # (N, H)
    y_pred_scaled_all_horizons: np.ndarray, # (N, H)
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
    prediction_time_index_h1: pd.DatetimeIndex, # Time index for the first horizon predictions
    forecast_horizons: List[int],
    title: str,
    xlabel: str = "Time",
    ylabel: str = "Value (Original Scale)",
    max_points: Optional[int] = 1000 # Limit points for clarity
) -> plt.Figure:
    """
    Create a time series plot comparing actual values to predictions for multiple horizons.
    Predictions for each horizon are plotted at their corresponding target time step.

    Args:
        y_true_scaled_all_horizons: Ground truth values (N, H array) on the scaled (transformed) scale.
        y_pred_scaled_all_horizons: Predicted values (N, H array) on the scaled (transformed) scale.
        target_scaler: The scaler used for the target variable, needed for inverse transform.
        prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
            Length should be N.
        forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
        title: Title for the plot.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        max_points: Maximum number of points to display (subsamples if needed).

    Returns:
        The generated matplotlib Figure object.

    Raises:
        ValueError: If input shapes are incompatible or horizons list is invalid.
    """
    if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
        raise ValueError(f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match.")
    if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
        raise ValueError(f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}.")
    if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
        raise ValueError(f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match the number of predictions ({y_true_scaled_all_horizons.shape[0]}).")
    if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
        logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
    if not forecast_horizons or len(forecast_horizons) == 0:
        raise ValueError("forecast_horizons list cannot be empty.")

    logger.debug(f"Creating multi-horizon time series plot: {title}")
    setup_plot_style() # Apply standard style

    fig, ax = plt.subplots(figsize=(18, 8)) # Larger figure for multi-horizon

    n_points = y_true_scaled_all_horizons.shape[0]
    plot_indices = np.arange(n_points)

    if max_points and n_points > max_points:
        step = max(1, n_points // max_points)
        plot_indices = plot_indices[::step]
        # Subsample the data and index
        y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
        y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
        time_index_h1_plot = prediction_time_index_h1[plot_indices]
        effective_title = f'{title} (Sampled {len(plot_indices)} points)'
    else:
        y_true_scaled_plot = y_true_scaled_all_horizons
        y_pred_scaled_plot = y_pred_scaled_all_horizons
        time_index_h1_plot = prediction_time_index_h1
        effective_title = title

    # Inverse transform the subsampled data
    y_true_inv_plot = None
    y_pred_inv_plot = None
    if target_scaler is not None:
        try:
            # Scaler expects (N * H, 1), reshape (N, H) to (N*H, 1)
            y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
            y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
            # Reshape back to (N, H)
            y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
            y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
            logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
        except Exception as e:
            logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
            # Fallback to plotting scaled data if inverse transform fails
            y_true_inv_plot = y_true_scaled_plot
            y_pred_inv_plot = y_pred_scaled_plot
            ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
    else:
        # No scaler provided: plot the scaled values directly instead of falling through with None
        y_true_inv_plot = y_true_scaled_plot
        y_pred_inv_plot = y_pred_scaled_plot
        ylabel = f"{ylabel} (Scaled Data - No Scaler Provided)"

    if y_true_inv_plot is None or y_pred_inv_plot is None:
        # This should not happen with the fallback, but as a safeguard
        logger.error("Inverse transformed data is None, cannot plot.")
        return fig # Return empty figure

    # Plot Actuals (using h1's time index, as it's the reference point)
    ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-', markersize=4, linewidth=1.5, color='black') # Actuals for H1

    # Plot predictions for each horizon
    colors = sns.color_palette("viridis", len(forecast_horizons)) # Use palette for distinct colors
    linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1) # Cycle through linestyles

    for i, horizon in enumerate(forecast_horizons):
        preds_h = y_pred_inv_plot[:, i]
        # Calculate time index for this specific horizon by shifting the h1 index
        # Assumes the time index frequency is appropriate for the horizon steps
        try:
            time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h') # Assuming 'h' for hours
            ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x', linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
        except Exception as e:
            logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. Skipping plot for this horizon.", exc_info=True)

    # Configure plot appearance
    ax.set_title(effective_title, fontsize=16) # Slightly larger title
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(fontsize=10) # Smaller legend font
    ax.grid(True, linestyle='--', alpha=0.6)

    # Improve x-axis readability for datetimes
    fig.autofmt_xdate() # Auto-rotate date labels

    fig.tight_layout()

    return fig

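A minimal sketch of calling create_multi_horizon_time_series_plot, assuming hourly data and a StandardScaler fitted on the training targets (all data below is synthetic and illustrative):

    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    horizons = [1, 6, 12, 24]
    n = 500
    scaler = StandardScaler().fit(np.random.rand(2000, 1))  # stand-in for the real target scaler
    y_true_scaled = np.random.randn(n, len(horizons))       # (N, H) scaled ground truth
    y_pred_scaled = y_true_scaled + 0.1 * np.random.randn(n, len(horizons))
    time_index_h1 = pd.date_range("2024-01-01", periods=n, freq="h")

    fig = create_multi_horizon_time_series_plot(
        y_true_scaled_all_horizons=y_true_scaled,
        y_pred_scaled_all_horizons=y_pred_scaled,
        target_scaler=scaler,
        prediction_time_index_h1=time_index_h1,
        forecast_horizons=horizons,
        title="Multi-horizon forecast vs actuals",
    )
    save_plot(fig, "plots/multi_horizon.png")
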
def plot_loss_curve_from_csv(
    metrics_csv_path: Union[str, Path],
    output_path: Union[str, Path],
    title: str = "Training Loss Curve",
    train_loss_col: str = "train_loss", # Changed to match logging in model.py
    val_loss_col: str = "val_loss", # Common validation loss metric logged by PL
    epoch_col: str = "epoch"
) -> None:
    """
    Reads training metrics from a PyTorch Lightning CSVLogger file and plots
    training and validation loss curves over epochs.

    Args:
        metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
        output_path: Path where the plot image will be saved.
        title: Title for the plot.
        train_loss_col: Name of the column containing epoch-level training loss.
        val_loss_col: Name of the column containing epoch-level validation loss.
        epoch_col: Name of the column containing the epoch number.

    Raises:
        FileNotFoundError: If the metrics_csv_path does not exist.
        KeyError: If required columns are not found in the CSV.
        Exception: For other plotting or file reading errors.
    """
    logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
    metrics_path = Path(metrics_csv_path)
    if not metrics_path.is_file():
        raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")

    try:
        metrics_df = pd.read_csv(metrics_path)

        # Check if required columns exist
        required_cols = [epoch_col, train_loss_col]
        # Val loss column might be the scaled loss or the original scale MAE
        possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale'] # Include potential names

        found_val_col = None
        for col in possible_val_cols:
            if col in metrics_df.columns:
                found_val_col = col
                break

        if not found_val_col:
            missing_cols = [col for col in required_cols if col not in metrics_df.columns]
            raise KeyError(f"Missing required columns in {metrics_path}: {missing_cols} or a suitable validation loss/metric column from {possible_val_cols}.")

        # --- Plotting ---
        setup_plot_style() # Apply standard style
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color1 = 'tab:red'
        ax1.set_xlabel(epoch_col.capitalize())
        # Adjust ylabel based on actual column name used for train loss
        ax1.set_ylabel(train_loss_col.replace('_epoch','').replace('_',' ').capitalize(), color=color1)
        # Drop NaNs specific to this column for plotting integrity
        train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])
        # Filter for epoch column only if needed (usually not for loss plots)
        # train_plot_data = train_plot_data[train_plot_data[epoch_col].notna()]

        # Ensure epoch starts from 0 or 1 consistently
        if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
            ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
            logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
        else:
            ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')

        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')

        # Validation loss/metric plotting on twin axis
        ax2 = ax1.twinx()
        color2 = 'tab:blue'
        # Adjust ylabel based on actual column name used for val metric
        ax2.set_ylabel(found_val_col.replace('_epoch','').replace('_',' ').capitalize(), color=color2)
        # Drop NaNs specific to the found validation column
        val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])
        # val_plot_data = val_plot_data[val_plot_data[epoch_col].notna()] # Ensure epoch is not NaN

        # Ensure epoch starts from 0 or 1 consistently
        if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
            # If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
            ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
            logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
        else:
            ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')

        ax2.tick_params(axis='y', labelcolor=color2)

        # Add legend manually combining lines from both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='upper right')

        plt.title(title, fontsize=14)
        fig.tight_layout() # Otherwise the right y-label is slightly clipped

        # Save the plot
        save_plot(fig, output_path)

    except pd.errors.EmptyDataError:
        logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
    except KeyError as e:
        logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
        raise # Re-raise specific error after logging
    except Exception as e:
        logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
        raise # Re-raise general errors
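
Finally, a sketch of generating the loss curve from a PyTorch Lightning CSVLogger run (paths are illustrative; the column names fall back to the defaults in the signature above):

    plot_loss_curve_from_csv(
        metrics_csv_path="lightning_logs/version_0/metrics.csv",
        output_path="plots/training_loss_curve.png",
        title="Training vs Validation Loss",
    )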