intermediate backup
@@ -1,30 +1,34 @@
import argparse
import logging
import sys
import os
import random
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Import necessary components from your project structure
# Assuming forecasting_model is a package installable or in PYTHONPATH
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from typing import Dict, List, Any, Optional
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation

# Import the new classic training function
from forecasting_model.train.classic import run_classic_training
from typing import Dict, List, Optional, Tuple, Union
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
@@ -33,396 +37,552 @@ pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
# Configure logging early. Level might be adjusted by config.
# Configure logging early. Level might be adjusted by config later.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()

# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments."""
parser = argparse.ArgumentParser(
description="Run the Time Series Forecasting training pipeline.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config',
type=str,
default='config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--seed',
type=int,
default=None, # Default to None, use config value if not provided
help="Override random seed defined in config."
)
parser.add_argument(
'--debug',
action='store_true',
help="Override log level to DEBUG."
)
parser.add_argument(
'--output-dir',
type=str,
default='output/cv_results', # Default output base directory
help="Base directory for saving cross-validation results (checkpoints, logs, plots)."
)

args = parser.parse_args()
return args

# --- Helper Functions ---

def load_config(config_path: Path) -> MainConfig:
# --- Single Fold Processing Function ---
def run_single_fold(
fold_num: int,
train_idx: np.ndarray,
val_idx: np.ndarray,
test_idx: np.ndarray,
config: MainConfig,
full_df: pd.DataFrame,
output_base_dir: Path # Receives Path object from run_training_pipeline
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
"""
Load and validate configuration from YAML file using Pydantic.
Runs the pipeline for a single cross-validation fold.

Args:
config_path: Path to the YAML configuration file.
fold_num: The zero-based index of the current fold.
train_idx: Indices for the training set.
val_idx: Indices for the validation set.
test_idx: Indices for the test set.
config: The main configuration object.
full_df: The complete raw DataFrame.
output_base_dir: The base directory Path for saving results.

Returns:
Validated MainConfig object.

Raises:
FileNotFoundError: If the config file doesn't exist.
yaml.YAMLError: If the file is not valid YAML.
pydantic.ValidationError: If the config doesn't match the schema.
A tuple containing:
- fold_metrics: Dictionary of test metrics for the fold (e.g., {'MAE': ..., 'RMSE': ...}).
- best_val_score: The best validation score achieved during training (or None).
- saved_model_path: Path to the best saved model checkpoint (or None).
- saved_target_scaler_path: Path to the saved target scaler (or None).
- saved_input_size_path: Path to the saved input size file (or None).
- saved_config_path: Path to the saved config file for this fold (or None).
"""
if not config_path.is_file():
logger.error(f"Configuration file not found at: {config_path}")
raise FileNotFoundError(f"Config file not found: {config_path}")
fold_start_time = time.perf_counter()
fold_id = fold_num + 1 # User-facing fold number (1-based)
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")

fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")

fold_metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Default in case of failure
best_val_score: Optional[float] = None
best_model_path_str: Optional[str] = None # Use a different name for the string from callback

# Variables to hold prediction results for plotting later
all_preds_scaled: Optional[np.ndarray] = None
all_targets_scaled: Optional[np.ndarray] = None
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Need to keep scaler reference
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None

# Variables to store paths of saved artifacts
saved_model_path: Optional[Path] = None
saved_target_scaler_path: Optional[Path] = None
saved_input_size_path: Optional[Path] = None
saved_config_path: Optional[Path] = None

logger.info(f"Loading configuration from: {config_path}")
try:
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
# Keep scaler and input_size references returned by prepare_fold_data_and_loaders
train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders( # Renamed target_scaler
full_df=full_df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
)
target_scaler = target_scaler_fold # Store the scaler in the outer scope
logger.info(f"Data loaders prepared. Input size determined: {input_size}")

# Validate configuration using Pydantic model
config = MainConfig(**config_dict)
logger.info("Configuration loaded and validated successfully.")
return config
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
raise
except Exception as e: # Catches Pydantic validation errors too
logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
raise
# Save necessary items for potential later use (e.g., ensemble)
# Capture the paths when saving
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
torch.save(target_scaler, saved_target_scaler_path)
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
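# Note: torch.save falls back to pickle for non-tensor objects, so the scaler
# and the DataLoader above are serialized together with their underlying
# dataset; reloading them later (e.g. for ensemble evaluation) presumably
# requires the same class definitions to be importable.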

def set_seeds(seed: Optional[int] = 42) -> None:
"""
Set random seeds for reproducibility across libraries.
# Save input size and capture path
saved_input_size_path = fold_output_dir / "input_size.pt"
torch.save(input_size, saved_input_size_path)

Args:
seed: The seed value to use. If None, uses default 42.
"""
if seed is None:
seed = 42
logger.warning(f"No seed provided, using default seed: {seed}")
else:
logger.info(f"Setting random seed: {seed}")
# Save config for this fold (needed for reloading model) and capture path
config_dump = config.model_dump()
saved_config_path = fold_output_dir / "config.yaml" # Capture the path before saving

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ensure reproducibility for CUDA operations where possible
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU
# These settings can slow down training but improve reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# PyTorch Lightning seeding (optional, as we seed torch directly)
# pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
with open(saved_config_path, 'w') as f:
yaml.dump(config_dump, f, default_flow_style=False)

def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""
Calculate mean and standard deviation of metrics across folds.
Handles potential NaN values by ignoring them.

Args:
all_fold_metrics: A list where each element is a dictionary of
metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
# --- Model Initialization ---
model = LSTMForecastLightningModule(
model_config=config.model,
train_config=config.training,
input_size=input_size,
target_scaler=target_scaler_fold # Pass scaler during init
)
logger.info("LSTMForecastLightningModule initialized.")

Returns:
A dictionary where keys are metric names and values are dicts
containing 'mean' and 'std' for that metric across folds.
Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
"""
if not all_fold_metrics:
logger.warning("Received empty list for metric aggregation.")
return {}
# --- PyTorch Lightning Callbacks ---
# Ensure monitor_metric matches the exact name logged in model.py
monitor_metric = "val_MeanAbsoluteError_Original_Scale" # Corrected metric name
monitor_mode = "min"
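# The monitored name must match a key logged via self.log(...) in the
# LightningModule exactly; if it never appears, EarlyStopping/ModelCheckpoint
# will warn or raise (depending on the Lightning version and their strict
# settings) rather than silently track the wrong metric.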

aggregated: Dict[str, Dict[str, float]] = {}
# Get metric names from the first valid fold's results
first_valid_metrics = next((m for m in all_fold_metrics if m), None)
if not first_valid_metrics:
logger.warning("No valid fold metrics found for aggregation.")
return {}
metric_names = list(first_valid_metrics.keys())
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001,
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

for metric in metric_names:
# Collect values for this metric across all folds, ignoring NaNs
values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
valid_values = [v for v in values if v is not None and not np.isnan(v)]
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}",
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

lr_monitor = LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
callbacks.append(early_stop_callback)

# --- PyTorch Lightning Logger ---
# Log to a subdir specific to the fold, relative to output_base_dir
log_dir = output_base_dir / f"fold_{fold_id:02d}" / "training_logs"
pl_logger = CSVLogger(save_dir=str(log_dir.parent), name=log_dir.name, version='') # Use name for subdir, version='' to avoid 'version_0'
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")


# --- PyTorch Lightning Trainer ---
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None
precision = getattr(config.training, 'precision', 32)

trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10),
enable_progress_bar=True,
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")

# Store best validation score and path
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
# Capture the best model path reported by the checkpoint callback
best_model_path_str = trainer.checkpoint_callback.best_model_path # Capture the string path
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
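# best_model_score is a 0-d tensor (or None), hence .item(); the explicit
# 'is not None' check matters because a perfectly valid best score of 0.0
# would otherwise be treated as missing.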

if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {best_val_score:.4f}")
# Check if best_model_path was actually set by the callback
if best_model_path_str:
saved_model_path = Path(best_model_path_str) # Convert string to Path object and store
logger.info(f"Best model checkpoint path: {best_model_path_str}")
else:
logger.warning(f"ModelCheckpoint callback did not report a best_model_path for Fold {fold_id}.")

if not valid_values:
logger.warning(f"No valid values found for metric '{metric}' across folds.")
mean_val = np.nan
std_val = np.nan
else:
mean_val = float(np.mean(valid_values))
std_val = float(np.std(valid_values))
logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path_str = None # Ensure string path is None if no best score

aggregated[metric] = {'mean': mean_val, 'std': std_val}

return aggregated
# --- Prediction on Test Set ---
logger.info(f"Starting prediction for Fold {fold_id} using {'best checkpoint' if saved_model_path else 'last model'}...")
# Use the best checkpoint path if available, otherwise use the in-memory model instance
ckpt_path_for_predict = str(saved_model_path) if saved_model_path else None # Use the saved Path object, convert to string for ckpt_path


prediction_results_list = trainer.predict(
model=model, # Use the in-memory model instance
dataloaders=test_loader,
ckpt_path=ckpt_path_for_predict # Specify checkpoint path if needed, though using model=model is typical
)
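# Expected behaviour (assumption, not verified here): when ckpt_path points to
# a valid checkpoint, Lightning restores those weights before predicting; with
# ckpt_path=None the in-memory weights from trainer.fit() are used, so
# prediction still works even if checkpointing produced no file.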


# --- Process Prediction Results & Get Time Index ---
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
all_preds_scaled = None # Ensure these are None on failure
all_targets_scaled = None
else:
try:
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
n_predictions = len(all_preds_scaled)

if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
if len(all_targets_scaled) != n_predictions:
logger.error(f"Fold {fold_id}: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
raise ValueError("Prediction and target count mismatch during evaluation.")
else:
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate or plot original scale targets.")
all_targets_scaled = None


logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")

# --- Calculate Correct Time Index for Plotting (First Horizon) ---
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"

prediction_target_time_index_h1 = None

if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
try:
test_block_index = full_df.index[test_idx]
seq_len = config.features.sequence_length
first_horizon = config.features.forecast_horizon[0]

target_indices_h1 = test_idx + seq_len + first_horizon - 1
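# Worked example with hypothetical values: for sequence_length=24 and
# forecast_horizon=[1], a sample whose input window starts at absolute row i
# consumes rows i .. i+23, and its first-horizon target sits at row
# i + 24 + 1 - 1 = i + 24, which is what test_idx + seq_len + first_horizon - 1
# computes element-wise over the fold's test indices.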

valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]

if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
if len(prediction_target_time_index_h1) != n_predictions:
logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
prediction_target_time_index_h1 = None

else:
logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
prediction_target_time_index_h1 = None


# Save the calculated index if it's valid and evaluation plots are enabled
if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
try:
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
except Exception as save_e:
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")

elif prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")

except Exception as e:
logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
prediction_target_time_index_h1 = None
else:
logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
if prediction_target_time_index_h1_path.exists():
try:
prediction_target_time_index_h1_path.unlink()
logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
# --- End Index Calculation and Saving ---


# --- Evaluation ---
if all_targets_scaled is not None: # Only evaluate if targets are available
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled, # Pass the (N, H) array
y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
target_scaler=target_scaler,
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=str(fold_output_dir),
plot_subdir="plots",
# Pass the calculated index for the targets being plotted (h1 reference)
prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
plot_title_prefix=f"CV Fold {fold_id}"
)
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
else:
logger.error(f"Skipping evaluation for Fold {fold_id} due to missing targets.")


# --- Multi-Horizon Plotting ---
if config.evaluation.save_plots and all_preds_scaled is not None and all_targets_scaled is not None and prediction_target_time_index_h1 is not None and target_scaler is not None:
logger.info(f"Generating multi-horizon plot for Fold {fold_id}...")
try:
multi_horizon_plot_path = fold_output_dir / "plots" / "multi_horizon_forecast.png"
# Need to import save_plot function if it's not already imported
# from forecasting_model.io.plotting import save_plot # Ensure this import is present if needed
fig = create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons=all_targets_scaled,
y_pred_scaled_all_horizons=all_preds_scaled,
target_scaler=target_scaler,
prediction_time_index_h1=prediction_target_time_index_h1,
forecast_horizons=config.features.forecast_horizon,
title=f"Fold {fold_id} Multi-Horizon Forecast",
max_points=1000 # Limit points for clarity
)
# Check if save_plot is available or use fig.savefig()
try:
save_plot(fig, multi_horizon_plot_path)
except NameError:
# Fallback if save_plot is not defined/imported
fig.savefig(multi_horizon_plot_path)
plt.close(fig) # Close the figure after saving
logger.warning("Using fig.savefig as save_plot function was not found.")

except Exception as plot_e:
logger.error(f"Fold {fold_id}: Failed to generate multi-horizon plot: {plot_e}", exc_info=True)
elif config.evaluation.save_plots:
logger.warning(f"Fold {fold_id}: Skipping multi-horizon plot due to missing data (preds, targets, time index, or scaler).")


except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
except ValueError as e: # Catch specific error from above
logger.error(f"ValueError processing prediction results for Fold {fold_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)

# --- Plot Loss Curve for Fold ---
try:
actual_log_dir = Path(pl_logger.log_dir) / pl_logger.name # Should be .../fold_XX/training_logs
metrics_file_path = actual_log_dir / "metrics.csv"
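# Assumption: metrics.csv is expected where the CSVLogger above (version='')
# wrote it; if Lightning resolves log_dir differently or adds a version_*
# subfolder, the is_file() check below fails and only the warning is logged.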

if metrics_file_path.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file_path,
output_path=fold_output_dir / "plots" / "loss_curve.png", # Save in plots subdir
title=f"Fold {fold_id} Training Progression",
train_loss_col='train_loss',
val_loss_col='val_loss' # This function handles fallback
)
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
else:
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
# --- End Loss Curve Plotting ---

def save_results(results: Dict, filename: Path):
"""Save dictionary results to a JSON file."""
try:
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
json.dump(results, f, indent=4)
logger.info(f"Saved results to {filename}")
except Exception as e:
logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
# Ensure paths are None if an error occurs before they are set
if saved_model_path is None: saved_model_path = None
if saved_target_scaler_path is None: saved_target_scaler_path = None


finally:
# Clean up GPU memory explicitly
del model, trainer # Ensure objects are deleted before clearing cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared CUDA cache.")

# Delete loaders explicitly if they might hold references
del train_loader, val_loader, test_loader

fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")

# Return the calculated fold metrics, best validation score, and saved artifact paths
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path


# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
"""Runs the full cross-validation training and evaluation pipeline."""
"""Runs the full training and evaluation pipeline based on config flags."""
start_time = time.perf_counter()
logger.info(f"Starting training pipeline. Results will be saved to: {output_base_dir}")
output_base_dir.mkdir(parents=True, exist_ok=True)

# --- Data Loading ---
try:
df = load_raw_data(config.data)
except Exception as e:
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
sys.exit(1) # Cannot proceed without data

# --- Cross-Validation Setup ---
try:
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
sys.exit(1)

# --- Initialize results ---
all_fold_test_metrics: List[Dict[str, float]] = []
all_fold_best_val_scores: Dict[int, Optional[float]] = {} # Store best val score per fold
all_fold_best_val_scores: Dict[int, Optional[float]] = {}
aggregated_metrics: Dict = {}
final_results: Dict = {} # Initialize empty results dict

# --- Cross-Validation Loop ---
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_start_time = time.perf_counter()
fold_id = fold_num + 1
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")

fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")

if config.run_cross_validation:
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
try:
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
full_df=df,
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)

for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
# Unpack the two new return values from run_single_fold
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
config=config,
full_df=df,
output_base_dir=output_base_dir
)
logger.info(f"Data loaders prepared. Input size determined: {input_size}")

# --- Model Initialization ---
# Pass input_size directly, ModelConfig no longer holds it.
# Ensure forecast horizon is consistent (checked in MainConfig validation)
current_model_config = config.model # Use the validated model config

model = LSTMForecastLightningModule(
model_config=current_model_config, # Does not contain input_size
train_config=config.training,
input_size=input_size, # Pass the dynamically determined input_size
target_scaler=target_scaler # Pass the fold-specific scaler
)
logger.info("LSTMForecastLightningModule initialized.")

# --- PyTorch Lightning Callbacks ---
# Monitor the validation MAE on the original scale (logged by LightningModule)
monitor_metric = "val_mae_orig_scale"
monitor_mode = "min"

early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001, # Minimum change to qualify as improvement
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

# Checkpoint callback to save the best model based on validation metric
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}", # {{epoch}}-{{val_loss:.2f}} etc. possible
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

# Learning rate monitor callback
lr_monitor = LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
callbacks.append(early_stop_callback)

# --- PyTorch Lightning Logger ---
# Log metrics to a CSV file within the fold directory
pl_logger = CSVLogger(save_dir=str(output_base_dir), name=f"fold_{fold_id:02d}", version='logs')
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")

# --- PyTorch Lightning Trainer ---
# Determine accelerator and devices based on PyTorch check
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None # Or specify specific GPU IDs [0], [1] etc.
precision = getattr(config.training, 'precision', 32) # Default to 32-bit

trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10), # Log ~10 times per epoch
enable_progress_bar=True, # Set to False for less verbose runs (e.g., HPO)
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
# deterministic=True, # For stricter reproducibility (can slow down)
)
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")

# Store best validation score for this fold
best_val_score = trainer.checkpoint_callback.best_model_score
best_model_path = trainer.checkpoint_callback.best_model_path
all_fold_best_val_scores[fold_id] = best_val_score.item() if best_val_score else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {all_fold_best_val_scores[fold_id]:.4f}")
logger.info(f"Best model checkpoint path: {best_model_path}")
else:
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path = None # Ensure evaluation doesn't try to load 'best' if checkpointing failed

# --- Prediction on Test Set ---
# Use trainer.predict() to get model outputs
logger.info(f"Starting prediction for Fold {fold_id} using best checkpoint...")
# predict_step returns dict {'preds_scaled': ..., 'targets_scaled': ...}
# We pass the test_loader here, which yields (x, y) pairs, so predict_step will include targets
prediction_results_list = trainer.predict(
# model=model, # Not needed if using ckpt_path
ckpt_path=best_model_path if best_model_path else 'last', # Load best model or last if best failed
dataloaders=test_loader
# return_predictions=True # Default is True
)

# Check if prediction returned results
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
else:
try:
# Concatenate predictions and targets from predict_step results
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
# Check if targets were included (they should be if using test_loader)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
else:
# This case shouldn't happen if using test_loader, but good safeguard
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate.")
raise ValueError("Targets missing from prediction results.")


# --- Final Evaluation & Plotting ---
logger.info(f"Processing prediction results for Fold {fold_id}...")
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled,
y_pred_scaled=all_preds_scaled,
target_scaler=target_scaler, # Use the scaler from this fold
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=output_base_dir, # Base dir for saving plots etc.
# time_index=df.iloc[test_idx].index # Pass time index if needed
)
# Save fold metrics
save_results(fold_metrics, fold_output_dir / "test_metrics.json")

except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
fold_metrics = {'MAE': np.nan, 'RMSE': np.nan}

all_fold_test_metrics.append(fold_metrics)
all_fold_best_val_scores[fold_num + 1] = best_val_score

# --- (Optional) Log final test metrics using trainer.test() ---
# If you want the metrics logged by test_step aggregated, call test now.
# logger.info(f"Logging final test metrics via trainer.test() for Fold {fold_id}...")
# try:
# trainer.test(ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader, verbose=False)
# except Exception as e:
# logger.warning(f"trainer.test() call failed for Fold {fold_id}: {e}")
# --- Aggregation and Reporting for CV ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
final_results['aggregated_test_metrics'] = aggregated_metrics
final_results['per_fold_test_metrics'] = all_fold_test_metrics
final_results['per_fold_best_val_scores'] = all_fold_best_val_scores
# Save intermediate results after CV
save_results(final_results, output_base_dir / "aggregated_cv_results.json")

except Exception as e:
# Catch errors during the fold processing (data prep, training, prediction, eval)
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
all_fold_test_metrics.append({'MAE': np.nan, 'RMSE': np.nan})


# --- Cleanup per fold ---
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared CUDA cache.")

fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")


# --- Aggregation and Final Reporting ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)

# Save aggregated results
final_results = {
'aggregated_test_metrics': aggregated_metrics,
'per_fold_test_metrics': all_fold_test_metrics,
'per_fold_best_val_scores': all_fold_best_val_scores,
}
save_results(final_results, output_base_dir / "aggregated_cv_results.json")


# Log final results
logger.info("--- Aggregated Cross-Validation Test Results ---")
if aggregated_metrics:
for metric, stats in aggregated_metrics.items():
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
else:
logger.warning("No metrics available for aggregation.")
logger.info("Skipping Cross-Validation loop as per config.")


# --- Ensemble Evaluation ---
if config.run_ensemble_evaluation:
# The validator in MainConfig already ensures run_cross_validation is also true here
logger.info("Starting ensemble evaluation...")
try:
ensemble_results = run_ensemble_evaluation(
config=config, # Pass config for context if needed by sub-functions
output_base_dir=output_base_dir
)
if ensemble_results:
logger.info("Ensemble evaluation completed successfully")
final_results['ensemble_results'] = ensemble_results
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
else:
logger.warning("No ensemble results were generated (potentially < 2 folds available).")
except Exception as e:
logger.error(f"Error during ensemble evaluation: {e}", exc_info=True)
else:
logger.info("Skipping Ensemble evaluation as per config.")


# --- Classic Training Run ---
if config.run_classic_training:
logger.info("Starting classic training run...")
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
try:
# Call the original classic training function directly
classic_metrics = run_classic_training(
config=config,
full_df=df,
output_base_dir=output_base_dir # It creates classic_run subdir internally
)
if classic_metrics:
logger.info(f"Classic training run completed. Test Metrics: {classic_metrics}")
final_results['classic_training_results'] = classic_metrics
save_results(final_results, output_base_dir / "aggregated_cv_results.json")

# --- Plot Loss Curve for Classic Run ---
try:
classic_log_dir = classic_output_dir / "training_logs"
metrics_file = classic_log_dir / "metrics.csv"
version_dirs = list(classic_log_dir.glob("version_*"))
if version_dirs:
metrics_file = version_dirs[0] / "metrics.csv"

if metrics_file.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file,
output_path=classic_output_dir / "loss_curve.png",
title="Classic Run Training Progression",
train_loss_col='train_loss', # Changed from 'train_loss_epoch'
val_loss_col='val_loss' # Check your logged metric names
)
else:
logger.warning(f"Classic Run: Could not find metrics.csv at {metrics_file} for loss curve plot.")
except Exception as plot_e:
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
# --- End Classic Loss Plotting ---

else:
logger.warning("Classic training run did not produce metrics.")
except Exception as e:
logger.error(f"Error during classic training run: {e}", exc_info=True)
else:
logger.info("Skipping Classic training run as per config.")


# --- Final Logging Summary ---
logger.info("--- Final Summary ---")
# Log aggregated CV results if they exist
if 'aggregated_test_metrics' in final_results and final_results['aggregated_test_metrics']:
logger.info("--- Aggregated Cross-Validation Test Results ---")
for metric, stats in final_results['aggregated_test_metrics'].items():
logger.info(f"{metric}: {stats.get('mean', np.nan):.4f} ± {stats.get('std', np.nan):.4f}")
elif config.run_cross_validation:
logger.warning("Cross-validation was run, but no metrics were aggregated.")

# Log aggregated ensemble results if they exist
if 'ensemble_results' in final_results and final_results['ensemble_results']:
logger.info("--- Aggregated Ensemble Test Results (Mean over Test Folds) ---")
agg_ensemble = {}
for fold_res in final_results['ensemble_results'].values():
if isinstance(fold_res, dict):
for method, metrics in fold_res.items():
if method not in agg_ensemble: agg_ensemble[method] = {}
if isinstance(metrics, dict):
for m_name, m_val in metrics.items():
if m_name not in agg_ensemble[method]: agg_ensemble[method][m_name] = []
agg_ensemble[method][m_name].append(m_val)
else: logger.warning(f"Skipping non-dict metrics for ensemble method '{method}'.")
else: logger.warning("Skipping non-dict fold result in ensemble aggregation.")

for method, metrics_data in agg_ensemble.items():
logger.info(f"  Ensemble Method: {method}")
for m_name, values in metrics_data.items():
valid_vals = [v for v in values if v is not None and not np.isnan(v)]
if valid_vals: logger.info(f"    {m_name}: {np.mean(valid_vals):.4f} ± {np.std(valid_vals):.4f}")
else: logger.info(f"    {m_name}: N/A")


# Log classic results if they exist
if 'classic_training_results' in final_results and final_results['classic_training_results']:
logger.info("--- Classic Training Test Results ---")
classic_res = final_results['classic_training_results']
for metric, value in classic_res.items():
logger.info(f"{metric}: {value:.4f}")

logger.info("-------------------------------------------------")

end_time = time.perf_counter()
@@ -434,12 +594,6 @@ def run():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
output_dir = Path(args.output_dir)

# Adjust log level if debug flag is set
if args.debug:
logger.setLevel(logging.DEBUG)
logger.debug("# --- Debug mode enabled. --- #")

# --- Configuration Loading ---
try:
@@ -448,10 +602,20 @@ def run():
# Error already logged in load_config
sys.exit(1)

# --- Seed Setting ---
# Use command-line seed if provided, otherwise use config seed
seed = args.seed if args.seed is not None else getattr(config, 'random_seed', 42)
set_seeds(seed)
# --- Setup based on Config ---
# 1. Set Log Level
log_level_name = config.log_level.upper()
log_level = getattr(logging, log_level_name, logging.INFO)
logger.setLevel(log_level)
logger.info(f"Log level set to: {log_level_name}")
if log_level == logging.DEBUG:
logger.debug("# --- Debug mode enabled via config. --- #")

# 2. Set Seed
set_seeds(config.random_seed)

# 3. Determine Output Directory
output_dir = Path(config.output_dir)

# --- Pipeline Execution ---
try:
@@ -459,7 +623,7 @@ def run():

except SystemExit as e:
logger.warning(f"Pipeline exited with code {e.code}.")
sys.exit(e.code) # Propagate exit code
sys.exit(e.code)
except Exception as e:
logger.critical(f"A critical error occurred during pipeline execution: {e}", exc_info=True)
sys.exit(1)