intermediate backup

This commit is contained in:
2025-05-03 20:46:14 +02:00
parent 2b0a5728d4
commit 6542caf48f
38 changed files with 4513 additions and 1067 deletions

View File

@ -1,11 +1,15 @@
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Optional, Union
from typing import Optional, Union, List
import logging
import pandas as pd
from pathlib import Path
# Assuming sklearn scalers are available
from sklearn.preprocessing import StandardScaler, MinMaxScaler
logger = logging.getLogger(__name__)
def setup_plot_style(use_seaborn: bool = True) -> None:
@ -17,14 +21,16 @@ def setup_plot_style(use_seaborn: bool = True) -> None:
"""
if use_seaborn:
try:
sns.set_theme(style="whitegrid", palette="muted")
plt.rcParams['figure.figsize'] = (12, 6) # Default figure size
# Use a different style that might be better for multiple lines
sns.set_theme(style="whitegrid", palette="viridis") # Changed palette
plt.rcParams['figure.figsize'] = (15, 7) # Slightly larger default figure size
logger.debug("Seaborn plot style set.")
except Exception as e:
logger.warning(f"Failed to set seaborn theme: {e}. Using default matplotlib style.")
else:
# Optional: Define a default matplotlib style if seaborn is not used
plt.style.use('default')
plt.rcParams['figure.figsize'] = (15, 7)
logger.debug("Using default matplotlib plot style.")
def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
@ -49,16 +55,21 @@ def save_plot(fig: plt.Figure, filename: Union[str, Path]) -> None:
logger.info(f"Plot saved successfully to: {filepath}")
except OSError as e:
logger.error(f"Failed to create directory for plot {filepath}: {e}", exc_info=True)
raise # Re-raise OSError for directory creation issues
# Don't re-raise immediately, try closing figure first
# raise # Re-raise OSError for directory creation issues - Removed to ensure finally runs
except Exception as e:
logger.error(f"Failed to save plot to {filepath}: {e}", exc_info=True)
raise # Re-raise other saving errors
# Don't re-raise immediately, try closing figure first
finally:
# Close the figure to free up memory, regardless of saving success
plt.close(fig)
# Close the figure to free up memory, regardless of saving success or failure
try:
plt.close(fig)
logger.debug(f"Closed figure for plot {filepath}.")
except Exception as e:
logger.warning(f"Failed to close figure for plot {filepath}: {e}")
def create_time_series_plot(
x: np.ndarray,
x: Union[np.ndarray, pd.Index], # Allow pd.Index for time axis
y_true: np.ndarray,
y_pred: np.ndarray,
title: str,
@ -68,9 +79,9 @@ def create_time_series_plot(
) -> plt.Figure:
"""
Create a time series plot comparing actual vs predicted values.
NOTE: When using multi-horizon forecasts, this typically plots only ONE selected horizon.
Args:
x: The array for the x-axis (e.g., time steps, indices).
x: The array or index for the x-axis (e.g., time steps, datetime index). Should align with y_true/y_pred.
y_true: Ground truth values (1D array).
y_pred: Predicted values (1D array).
title: Title for the plot.
@ -84,8 +95,9 @@ def create_time_series_plot(
Raises:
ValueError: If input array shapes are incompatible.
"""
if not (x.shape == y_true.shape == y_pred.shape and x.ndim == 1):
raise ValueError("Input arrays (x, y_true, y_pred) must be 1D and have the same shape.")
# Add check for pd.Index for x
if not isinstance(x, (np.ndarray, pd.Index)) or x.shape[0] != y_true.shape[0] or x.shape[0] != y_pred.shape[0] or y_true.ndim != 1 or y_pred.ndim != 1:
raise ValueError(f"Input shapes mismatch or invalid types: x({type(x)}, {x.shape if hasattr(x, 'shape') else 'N/A'}), y_true({y_true.shape}), y_pred({y_pred.shape}). Expecting 1D y arrays and matching length x.")
if len(x) == 0:
logger.warning("Attempting to create time series plot with empty data.")
# Return an empty figure or raise error? Let's return empty.
@ -304,4 +316,243 @@ def create_residuals_distribution_plot(
ax.grid(True, axis='y', linestyle='--', alpha=0.6)
fig.tight_layout()
return fig
return fig
def create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons: np.ndarray, # (N, H)
y_pred_scaled_all_horizons: np.ndarray, # (N, H)
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]],
prediction_time_index_h1: pd.DatetimeIndex, # Time index for the first horizon predictions
forecast_horizons: List[int],
title: str,
xlabel: str = "Time",
ylabel: str = "Value (Original Scale)",
max_points: Optional[int] = 1000 # Limit points for clarity
) -> plt.Figure:
"""
Create a time series plot comparing actual values to predictions for multiple horizons.
Predictions for each horizon are plotted on their corresponding target time step.
Args:
y_true_scaled_all_horizons: Ground truth values (N, H array) on scaled scale.
y_pred_scaled_all_horizons: Predicted values (N, H array) on scaled scale.
target_scaler: The scaler used for the target variable, needed for inverse transform.
prediction_time_index_h1: DatetimeIndex for the first horizon (h=h1) predictions.
Length should be N.
forecast_horizons: List of forecast horizons (e.g., [1, 6, 12, 24]).
title: Title for the plot.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
max_points: Maximum number of points to display (subsamples if needed).
Returns:
The generated matplotlib Figure object.
Raises:
ValueError: If input shapes are incompatible or horizons list is invalid.
"""
if y_true_scaled_all_horizons.shape != y_pred_scaled_all_horizons.shape:
raise ValueError(f"Shapes of y_true_scaled_all_horizons {y_true_scaled_all_horizons.shape} and y_pred_scaled_all_horizons {y_pred_scaled_all_horizons.shape} must match.")
if y_true_scaled_all_horizons.ndim != 2 or y_true_scaled_all_horizons.shape[1] != len(forecast_horizons):
raise ValueError(f"y arrays must be 2D (N, H) where H is the number of horizons ({len(forecast_horizons)}). Shape is {y_true_scaled_all_horizons.shape}.")
if len(prediction_time_index_h1) != y_true_scaled_all_horizons.shape[0]:
raise ValueError(f"Length of prediction_time_index_h1 ({len(prediction_time_index_h1)}) must match the number of predictions ({y_true_scaled_all_horizons.shape[0]}).")
if not isinstance(prediction_time_index_h1, pd.DatetimeIndex):
logger.warning("prediction_time_index_h1 is not a DatetimeIndex. Time shifts may not work as expected.")
if not forecast_horizons or len(forecast_horizons) == 0:
raise ValueError("forecast_horizons list cannot be empty.")
logger.debug(f"Creating multi-horizon time series plot: {title}")
setup_plot_style() # Apply standard style
fig, ax = plt.subplots(figsize=(18, 8)) # Larger figure for multi-horizon
n_points = y_true_scaled_all_horizons.shape[0]
plot_indices = np.arange(n_points)
if max_points and n_points > max_points:
step = max(1, n_points // max_points)
plot_indices = plot_indices[::step]
# Subsample the data and index
y_true_scaled_plot = y_true_scaled_all_horizons[plot_indices]
y_pred_scaled_plot = y_pred_scaled_all_horizons[plot_indices]
time_index_h1_plot = prediction_time_index_h1[plot_indices]
effective_title = f'{title} (Sampled {len(plot_indices)} points)'
else:
y_true_scaled_plot = y_true_scaled_all_horizons
y_pred_scaled_plot = y_pred_scaled_all_horizons
time_index_h1_plot = prediction_time_index_h1
effective_title = title
# Inverse transform the subsampled data
y_true_inv_plot = None
y_pred_inv_plot = None
if target_scaler is not None:
try:
# Scaler expects (N * H, 1), reshape (N, H) to (N*H, 1)
y_true_inv_plot_flat = target_scaler.inverse_transform(y_true_scaled_plot.reshape(-1, 1))
y_pred_inv_plot_flat = target_scaler.inverse_transform(y_pred_scaled_plot.reshape(-1, 1))
# Reshape back to (N, H)
y_true_inv_plot = y_true_inv_plot_flat.reshape(y_true_scaled_plot.shape)
y_pred_inv_plot = y_pred_inv_plot_flat.reshape(y_pred_scaled_plot.shape)
logger.debug("Successfully inverse-transformed data for multi-horizon plot.")
except Exception as e:
logger.error(f"Failed to inverse transform data for multi-horizon plot: {e}", exc_info=True)
# Fallback to plotting scaled data if inverse transform fails
y_true_inv_plot = y_true_scaled_plot
y_pred_inv_plot = y_pred_scaled_plot
ylabel = f"{ylabel} (Scaled Data - Inverse Transform Failed)"
if y_true_inv_plot is None or y_pred_inv_plot is None:
# This should not happen with the fallback, but as a safeguard
logger.error("Inverse transformed data is None, cannot plot.")
return fig # Return empty figure
# Plot Actuals (using h1's time index, as it's the reference point)
ax.plot(time_index_h1_plot, y_true_inv_plot[:, 0], label='Actuals', marker='.', linestyle='-', markersize=4, linewidth=1.5, color='black') # Actuals for H1
# Plot predictions for each horizon
colors = sns.color_palette("viridis", len(forecast_horizons)) # Use palette for distinct colors
linestyles = ['-', '--', '-.', ':'] * (len(forecast_horizons) // 4 + 1) # Cycle through linestyles
for i, horizon in enumerate(forecast_horizons):
preds_h = y_pred_inv_plot[:, i]
# Calculate time index for this specific horizon by shifting the h1 index
# Assumes the time index frequency is appropriate for the horizon steps
try:
time_index_h = time_index_h1_plot + pd.to_timedelta(horizon - forecast_horizons[0], unit='h') # Assuming 'h' for hours
ax.plot(time_index_h, preds_h, label=f'Predicted (h={horizon})', marker='x', linestyle=linestyles[i], markersize=4, alpha=0.8, linewidth=1, color=colors[i])
except Exception as e:
logger.warning(f"Could not calculate time index for horizon {horizon}: {e}. Skipping plot for this horizon.", exc_info=True)
# Configure plot appearance
ax.set_title(effective_title, fontsize=16) # Slightly larger title
ax.set_xlabel(xlabel, fontsize=12)
ax.set_ylabel(ylabel, fontsize=12)
ax.legend(fontsize=10) # Smaller legend font
ax.grid(True, linestyle='--', alpha=0.6)
# Improve x-axis readability for datetimes
fig.autofmt_xdate() # Auto-rotate date labels
fig.tight_layout()
return fig
def plot_loss_curve_from_csv(
metrics_csv_path: Union[str, Path],
output_path: Union[str, Path],
title: str = "Training Loss Curve",
train_loss_col: str = "train_loss", # Changed to match logging in model.py
val_loss_col: str = "val_loss", # Common validation loss metric logged by PL
epoch_col: str = "epoch"
) -> None:
"""
Reads training metrics from a PyTorch Lightning CSVLogger file and plots
training and validation loss curves over epochs.
Args:
metrics_csv_path: Path to the metrics.csv file generated by CSVLogger.
output_path: Path where the plot image will be saved.
title: Title for the plot.
train_loss_col: Name of the column containing epoch-level training loss.
val_loss_col: Name of the column containing epoch-level validation loss.
epoch_col: Name of the column containing the epoch number.
Raises:
FileNotFoundError: If the metrics_csv_path does not exist.
KeyError: If required columns are not found in the CSV.
Exception: For other plotting or file reading errors.
"""
logger.info(f"Generating loss curve plot from: {metrics_csv_path}")
metrics_path = Path(metrics_csv_path)
if not metrics_path.is_file():
raise FileNotFoundError(f"Metrics CSV file not found at: {metrics_path}")
try:
metrics_df = pd.read_csv(metrics_path)
# Check if required columns exist
required_cols = [epoch_col, train_loss_col]
# Val loss column might be the scaled loss or the original scale MAE
possible_val_cols = [val_loss_col, 'val_MeanAbsoluteError_Original_Scale', 'val_mae_orig_scale'] # Include potential names
found_val_col = None
for col in possible_val_cols:
if col in metrics_df.columns:
found_val_col = col
break
if not found_val_col:
missing_cols = [col for col in required_cols if col not in metrics_df.columns]
raise KeyError(f"Missing required columns in {metrics_path}: {missing_cols} or a suitable validation loss/metric column from {possible_val_cols}.")
# --- Plotting ---
setup_plot_style() # Apply standard style
fig, ax1 = plt.subplots(figsize=(12, 6))
color1 = 'tab:red'
ax1.set_xlabel(epoch_col.capitalize())
# Adjust ylabel based on actual column name used for train loss
ax1.set_ylabel(train_loss_col.replace('_epoch','').replace('_',' ').capitalize(), color=color1)
# Drop NaNs specific to this column for plotting integrity
train_plot_data = metrics_df[[epoch_col, train_loss_col]].dropna(subset=[train_loss_col])
# Filter for epoch column only if needed (usually not for loss plots)
# train_plot_data = train_plot_data[train_plot_data[epoch_col].notna()]
# Ensure epoch starts from 0 or 1 consistently
if train_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax1.plot(train_plot_data[epoch_col] + 1, train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
logger.debug("Adjusting train loss x-axis by +1 for epoch alignment.")
else:
ax1.plot(train_plot_data[epoch_col], train_plot_data[train_loss_col], color=color1, label='Train Loss', marker='.', linestyle='-')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.grid(True, axis='y', linestyle='--', alpha=0.6, which='major')
# Validation loss/metric plotting on twin axis
ax2 = ax1.twinx()
color2 = 'tab:blue'
# Adjust ylabel based on actual column name used for val metric
ax2.set_ylabel(found_val_col.replace('_epoch','').replace('_',' ').capitalize(), color=color2)
# Drop NaNs specific to the found validation column
val_plot_data = metrics_df[[epoch_col, found_val_col]].dropna(subset=[found_val_col])
# val_plot_data = val_plot_data[val_plot_data[epoch_col].notna()] # Ensure epoch is not NaN
# Ensure epoch starts from 0 or 1 consistently
if val_plot_data[epoch_col].min() > 0 and 0 in metrics_df[epoch_col].unique():
# If epoch starts from 1 in plot data but 0 exists, adjust x-axis for alignment
ax2.plot(val_plot_data[epoch_col] + 1, val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
logger.debug("Adjusting val metric x-axis by +1 for epoch alignment.")
else:
ax2.plot(val_plot_data[epoch_col], val_plot_data[found_val_col], color=color2, label='Validation Metric', marker='x', linestyle='--')
ax2.tick_params(axis='y', labelcolor=color2)
# Add legend manually combining lines from both axes
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='upper right')
plt.title(title, fontsize=14)
fig.tight_layout() # Otherwise the right y-label is slightly clipped
# Save the plot
save_plot(fig, output_path)
except pd.errors.EmptyDataError:
logger.error(f"Metrics CSV file is empty: {metrics_csv_path}")
except KeyError as e:
logger.error(f"Could not find expected column in {metrics_csv_path}: {e}")
raise # Re-raise specific error after logging
except Exception as e:
logger.error(f"Failed to create or save loss curve plot from {metrics_csv_path}: {e}", exc_info=True)
raise # Re-raise general errors