# entrix_case_challange/forecasting_model_run.py
import logging
import sys
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Import necessary components from your project structure
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.utils.data_processing import (
prepare_fold_data_and_loaders
)
from forecasting_model.utils.dataset_splitter import TimeSeriesCrossValidationSplitter
from forecasting_model.io.data import load_raw_data
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.evaluation import evaluate_fold_predictions
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
# Import the new classic training function
from forecasting_model.train.classic import run_model_training
from typing import Dict, List, Optional, Tuple, Union
from forecasting_model.utils.helper import (
parse_arguments, load_config,
set_seeds, aggregate_cv_metrics,
save_results, calculate_h1_target_index
)
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
# Configure logging early. Level might be adjusted by config later.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Single Fold Processing Function ---
# noinspection PyInconsistentReturns
def run_single_fold(
fold_num: int,
train_idx: np.ndarray,
val_idx: np.ndarray,
test_idx: np.ndarray,
config: MainConfig,
full_df: pd.DataFrame,
output_base_dir: Path,
enable_progress_bar: bool = True
) -> Optional[Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]]:
"""
Runs the pipeline for a single cross-validation fold.
Args:
fold_num: The zero-based index of the current fold.
train_idx: Indices for the training set.
val_idx: Indices for the validation set.
test_idx: Indices for the test set.
config: The main configuration object.
full_df: The complete raw DataFrame.
output_base_dir: The base directory Path for saving results.
enable_progress_bar: Whether to enable progress bar.
Returns:
A tuple containing:
- fold_metrics: Dictionary of test metrics for the fold (e.g., {'MAE': ..., 'RMSE': ...}).
- best_val_score: The best validation score achieved during training (or None).
- saved_model_path: Path to the best saved model checkpoint (or None).
- saved_target_scaler_path: Path to the saved target scaler (or None).
- saved_data_scaler_path: Path to the saved data feature scaler (or None).
- saved_input_size_path: Path to the saved input size file (or None).
- saved_config_path: Path to the saved config file for this fold (or None).
"""
fold_start_time = time.perf_counter()
fold_id = fold_num + 1 # User-facing fold number (1-based)
logger.info(f"--- Starting Fold {fold_id}/{config.cross_validation.n_splits} ---")
fold_output_dir = output_base_dir / f"fold_{fold_id:02d}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Fold output directory: {fold_output_dir}")
fold_metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan} # Default in case of failure
best_val_score: Optional[float] = None
best_model_path_str: Optional[str] = None # Use a different name for the string from callback
# Variables to hold prediction results for plotting later
all_preds_scaled: Optional[np.ndarray] = None
all_targets_scaled: Optional[np.ndarray] = None
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
data_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
pl_logger = None
# Variables to store paths of saved artifacts
saved_model_path: Optional[Path] = None
saved_target_scaler_path: Optional[Path] = None
saved_data_scaler_path: Optional[Path] = None
saved_input_size_path: Optional[Path] = None
saved_config_path: Optional[Path] = None
try:
# --- Per-Fold Data Preparation ---
logger.info("Preparing data loaders for the fold...")
# Assume prepare_fold_data_and_loaders returns the data_scaler as the 5th element
# Modify this call based on the actual return signature of prepare_fold_data_and_loaders
train_loader, val_loader, test_loader, target_scaler_fold, data_scaler_fold, input_size = prepare_fold_data_and_loaders(
full_df=full_df,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
target_col=config.data.target_col, # Pass target col name explicitly
feature_config=config.features,
train_config=config.training,
eval_config=config.evaluation
)
target_scaler = target_scaler_fold # Store the target scaler in the outer scope
data_scaler = data_scaler_fold # Store the data scaler in the outer scope
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
# Save necessary items for potential later use (e.g., ensemble, inference)
# Capture the paths when saving
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
torch.save(target_scaler, saved_target_scaler_path)
saved_data_scaler_path = fold_output_dir / "data_scaler.pt"
torch.save(data_scaler, saved_data_scaler_path)
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
# Save input size and capture path
saved_input_size_path = fold_output_dir / "input_size.pt"
torch.save(input_size, saved_input_size_path)
# Save config for this fold (needed for reloading model) and capture path
config_dump = config.model_dump()
saved_config_path = fold_output_dir / "config.yaml" # Capture the path before saving
with open(saved_config_path, 'w') as f:
yaml.dump(config_dump, f, default_flow_style=False)
# --- Model Initialization ---
model = LSTMForecastLightningModule(
model_config=config.model,
train_config=config.training,
input_size=input_size,
target_scaler=target_scaler_fold,
data_scaler=data_scaler
)
logger.info("LSTMForecastLightningModule initialized.")
# --- PyTorch Lightning Callbacks ---
        # Must match the exact metric name logged in model.py
        monitor_metric = "val_MeanAbsoluteError"
monitor_mode = "min"
early_stop_callback = None
if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0001,
patience=config.training.early_stopping_patience,
verbose=True,
mode=monitor_mode
)
logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")
checkpoint_callback = ModelCheckpoint(
dirpath=fold_output_dir / "checkpoints",
filename=f"best_model_fold_{fold_id}",
save_top_k=1,
monitor=monitor_metric,
mode=monitor_mode,
verbose=True
)
logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, lr_monitor]
if early_stop_callback:
# noinspection PyTypeChecker
callbacks.append(early_stop_callback)
# --- PyTorch Lightning Logger ---
# Log to a subdir specific to the fold, relative to output_base_dir
log_dir = output_base_dir / f"fold_{fold_id:02d}" / "training_logs"
pl_logger = CSVLogger(save_dir=str(log_dir.parent), name=log_dir.name, version='') # Use name for subdir, version='' to avoid 'version_0'
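        # With version='', metrics.csv is written directly under <fold_dir>/training_logs/
        # (no version_0 subdirectory); the loss-curve plotting in the finally block below
        # reads it from pl_logger.log_dir.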
logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")
# --- PyTorch Lightning Trainer ---
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = 1 if accelerator == 'gpu' else None
precision = getattr(config.training, 'precision', 32)
trainer = pl.Trainer(
accelerator=accelerator,
check_val_every_n_epoch=config.training.check_val_n_epoch,
devices=devices,
enable_progress_bar=enable_progress_bar,
max_epochs=config.training.epochs,
callbacks=callbacks,
logger=pl_logger,
log_every_n_steps=max(1, len(train_loader)//10),
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
precision=precision,
)
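        # log_every_n_steps is chosen so that roughly ten training-loss points are logged per epoch.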
logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")
# --- Training ---
logger.info(f"Starting training for Fold {fold_id}...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Training finished for Fold {fold_id}.")
# Store best validation score and path
best_val_score_tensor = trainer.checkpoint_callback.best_model_score
# Capture the best model path reported by the checkpoint callback
best_model_path_str = trainer.checkpoint_callback.best_model_path # Capture the string path
best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None
if best_val_score is not None:
logger.info(f"Best validation score ({monitor_metric}) for Fold {fold_id}: {best_val_score:.4f}")
# Check if best_model_path was actually set by the callback
if best_model_path_str:
saved_model_path = Path(best_model_path_str) # Convert string to Path object and store
logger.info(f"Best model checkpoint path: {best_model_path_str}")
else:
logger.warning(f"ModelCheckpoint callback did not report a best_model_path for Fold {fold_id}.")
else:
logger.warning(f"Could not retrieve best validation score/path for Fold {fold_id} (metric: {monitor_metric}). Evaluation might use last model.")
best_model_path_str = None # Ensure string path is None if no best score
# --- Prediction on Test Set ---
logger.info(f"Starting prediction for Fold {fold_id} using {'best checkpoint' if saved_model_path else 'last model'}...")
# Use the best checkpoint path if available, otherwise use the in-memory model instance
ckpt_path_for_predict = str(saved_model_path) if saved_model_path else None # Use the saved Path object, convert to string for ckpt_path
        prediction_results_list = trainer.predict(
            model=model,  # In-memory model instance
            dataloaders=test_loader,
            ckpt_path=ckpt_path_for_predict  # If set, Lightning restores the best checkpoint weights; otherwise the last in-memory weights are used
        )
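        # Each element of prediction_results_list is the dict returned by the model's
        # predict_step; the code below expects at least 'preds_scaled' and, when the test
        # loader provides targets, 'targets_scaled' (both shaped (n_samples, n_horizons)).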
# --- Process Prediction Results & Get Time Index ---
if not prediction_results_list:
logger.error(f"Predict phase did not return any results for Fold {fold_id}. Check predict_step and logs.")
all_preds_scaled = None # Ensure these are None on failure
all_targets_scaled = None
else:
try:
all_preds_scaled = torch.cat([batch_res['preds_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
n_predictions = len(all_preds_scaled)
if 'targets_scaled' in prediction_results_list[0]:
all_targets_scaled = torch.cat([batch_res['targets_scaled'] for batch_res in prediction_results_list], dim=0).numpy()
if len(all_targets_scaled) != n_predictions:
logger.error(f"Fold {fold_id}: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
raise ValueError("Prediction and target count mismatch during evaluation.")
else:
logger.error(f"Targets not found in prediction results for Fold {fold_id}. Cannot evaluate or plot original scale targets.")
all_targets_scaled = None
logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
prediction_target_time_index_h1 = calculate_h1_target_index(
full_df=full_df,
test_idx=test_idx,
sequence_length=config.features.sequence_length,
forecast_horizon=config.features.forecast_horizon,
n_predictions=n_predictions,
fold_id=fold_id
)
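                # prediction_target_time_index_h1 holds, for each test window, the timestamp of
                # its first-step-ahead (h=1) target, so plots can be aligned on a real time axis.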
# --- Handle Saving/Cleanup of the Index File ---
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
if prediction_target_time_index_h1 is not None and config.evaluation.save_plots:
# Save the calculated index if valid and plots are enabled
try:
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
except Exception as save_e:
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
elif prediction_target_time_index_h1_path.exists():
# Remove outdated file if index is invalid/not calculated OR plots disabled
logger.debug(f"Removing potentially outdated time index file: {prediction_target_time_index_h1_path}")
try:
prediction_target_time_index_h1_path.unlink()
except OSError as e:
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
# --- End Index Calculation and Saving ---
# --- Evaluation ---
if all_targets_scaled is not None: # Only evaluate if targets are available
fold_metrics = evaluate_fold_predictions(
y_true_scaled=all_targets_scaled, # Pass the (N, H) array
y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
target_scaler=target_scaler,
data_scaler=data_scaler,
eval_config=config.evaluation,
fold_num=fold_num, # Pass zero-based index
output_dir=str(fold_output_dir),
plot_subdir="plots",
# Pass the calculated index for the targets being plotted (h1 reference)
prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
plot_title_prefix=f"CV Fold {fold_id}",
)
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
else:
logger.error(f"Skipping evaluation for Fold {fold_id} due to missing targets.")
# --- Multi-Horizon Plotting ---
if config.evaluation.save_plots and all_preds_scaled is not None and all_targets_scaled is not None and prediction_target_time_index_h1 is not None and target_scaler is not None:
logger.info(f"Generating multi-horizon plot for Fold {fold_id}...")
try:
multi_horizon_plot_path = fold_output_dir / "plots" / "multi_horizon_forecast.png"
                        # save_plot is imported at module level; the NameError fallback below
                        # covers environments where it is unavailable
fig = create_multi_horizon_time_series_plot(
y_true_scaled_all_horizons=all_targets_scaled,
y_pred_scaled_all_horizons=all_preds_scaled,
target_scaler=target_scaler,
prediction_time_index_h1=prediction_target_time_index_h1,
forecast_horizons=config.features.forecast_horizon,
title=f"Fold {fold_id} Multi-Horizon Forecast",
max_points=1000 # Limit points for clarity
)
# Check if save_plot is available or use fig.savefig()
try:
save_plot(fig, multi_horizon_plot_path)
except NameError:
# Fallback if save_plot is not defined/imported
fig.savefig(multi_horizon_plot_path)
plt.close(fig) # Close the figure after saving
logger.warning("Using fig.savefig as save_plot function was not found.")
except Exception as plot_e:
logger.error(f"Fold {fold_id}: Failed to generate multi-horizon plot: {plot_e}", exc_info=True)
elif config.evaluation.save_plots:
logger.warning(f"Fold {fold_id}: Skipping multi-horizon plot due to missing data (preds, targets, time index, or scaler).")
except KeyError as e:
logger.error(f"KeyError processing prediction results for Fold {fold_id}: Missing key {e}. Check predict_step return format.", exc_info=True)
except ValueError as e: # Catch specific error from above
logger.error(f"ValueError processing prediction results for Fold {fold_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
        # Artifact paths that were not set before the error simply remain None
    finally:
        # Clean up GPU memory explicitly; guard the deletions in case setup failed
        # before these objects were created
        try:
            del model, trainer
        except NameError:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("Cleared CUDA cache.")
        # Delete loaders explicitly in case they hold references
        try:
            del train_loader, val_loader, test_loader
        except NameError:
            pass
# --- Plot Loss Curve for Fold ---
if pl_logger and hasattr(pl_logger, 'log_dir') and pl_logger.log_dir: # Check if logger exists and has log_dir
try:
            # The logger's log_dir already includes the 'name' segment, so use it directly
            actual_log_dir = Path(pl_logger.log_dir)
metrics_file_path = actual_log_dir / "metrics.csv"
if metrics_file_path.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file_path,
# Save plot inside the specific fold's plot directory
output_path=fold_output_dir / "plots" / "loss_curve.png",
title=f"Fold {fold_id} Training Progression",
train_loss_col='train_loss', # Ensure these column names match your CSVLogger output
val_loss_col='val_loss' # Ensure these column names match your CSVLogger output
)
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
else:
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
except AttributeError:
logger.warning(f"Fold {fold_id}: Could not plot loss curve, CSVLogger object or log_dir attribute missing.")
except Exception as e:
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
else:
logger.warning(f"Fold {fold_id}: Skipping loss curve plot generation as CSVLogger was not properly initialized or log_dir is missing.")
# --- End Loss Curve Plotting ---
fold_end_time = time.perf_counter()
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
# Return the calculated fold metrics, best validation score, and saved artifact paths
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path
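# Sketch (commented out, not executed): one way the artifacts returned by run_single_fold
# could be reloaded later for ensemble evaluation or inference. The exact keyword arguments
# accepted by LSTMForecastLightningModule.load_from_checkpoint are an assumption based on
# the constructor call above; adjust them to the module's actual hyperparameter handling.
#
#   target_scaler = torch.load(saved_target_scaler_path)
#   data_scaler = torch.load(saved_data_scaler_path)
#   input_size = torch.load(saved_input_size_path)
#   model = LSTMForecastLightningModule.load_from_checkpoint(
#       str(saved_model_path),
#       model_config=config.model,
#       train_config=config.training,
#       input_size=input_size,
#       target_scaler=target_scaler,
#       data_scaler=data_scaler,
#   )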
# --- Main Training & Evaluation Function ---
def run_training_pipeline(config: MainConfig, output_base_dir: Path):
"""Runs the full training and evaluation pipeline based on config flags."""
start_time = time.perf_counter()
logger.info(f"Starting training pipeline. Results will be saved to: {output_base_dir}")
output_base_dir.mkdir(parents=True, exist_ok=True)
# --- Data Loading ---
try:
df = load_raw_data(config.data)
except Exception as e:
logger.critical(f"Failed to load raw data: {e}", exc_info=True)
sys.exit(1)
# --- Initialize results ---
all_fold_test_metrics: List[Dict[str, float]] = []
all_fold_best_val_scores: Dict[int, Optional[float]] = {}
aggregated_metrics: Dict = {}
final_results: Dict = {} # Initialize empty results dict
# --- Cross-Validation Loop ---
if config.run_cross_validation:
logger.info(f"Starting {config.cross_validation.n_splits}-Fold Cross-Validation...")
try:
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
except ValueError as e:
logger.critical(f"Failed to initialize CV splitter: {e}", exc_info=True)
sys.exit(1)
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
# Unpack the return values from run_single_fold, including the new data_scaler path
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, _input_size_path, _config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
config=config,
full_df=df,
output_base_dir=output_base_dir
)
all_fold_test_metrics.append(fold_metrics)
all_fold_best_val_scores[fold_num + 1] = best_val_score
# --- Aggregation and Reporting for CV ---
logger.info("Cross-validation finished. Aggregating results...")
aggregated_metrics = aggregate_cv_metrics(all_fold_test_metrics)
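        # aggregate_cv_metrics is expected to return, per metric, summary statistics across
        # folds (at least 'mean' and 'std'), which the final summary logging below consumes.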
final_results['aggregated_test_metrics'] = aggregated_metrics
final_results['per_fold_test_metrics'] = all_fold_test_metrics
final_results['per_fold_best_val_scores'] = all_fold_best_val_scores
# Save intermediate results after CV
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
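        # Note: aggregated_cv_results.json is re-written with additional keys as the ensemble
        # and classic-training stages below complete, so it reflects the latest finished stage.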
else:
logger.info("Skipping Cross-Validation loop as per config.")
# --- Ensemble Evaluation ---
if config.run_ensemble_evaluation:
# The validator in MainConfig already ensures run_cross_validation is also true here
logger.info("Starting ensemble evaluation...")
try:
ensemble_results = run_ensemble_evaluation(
config=config, # Pass config for context if needed by sub-functions
output_base_dir=output_base_dir
)
if ensemble_results:
logger.info("Ensemble evaluation completed successfully")
final_results['ensemble_results'] = ensemble_results
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
else:
logger.warning("No ensemble results were generated (potentially < 2 folds available).")
except Exception as e:
logger.error(f"Error during ensemble evaluation: {e}", exc_info=True)
else:
logger.info("Skipping Ensemble evaluation as per config.")
# --- Classic Training Run ---
if config.run_classic_training:
logger.info("Starting classic training run...")
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
try:
# Call the original classic training function directly
classic_metrics = run_model_training(
config=config,
full_df=df,
output_base_dir=output_base_dir # It creates classic_run subdir internally
)
if classic_metrics:
logger.info(f"Classic training run completed. Test Metrics: {classic_metrics}")
final_results['classic_training_results'] = classic_metrics
save_results(final_results, output_base_dir / "aggregated_cv_results.json")
# --- Plot Loss Curve for Classic Run ---
try:
                    classic_log_dir = classic_output_dir / "training_logs"
                    metrics_file = classic_log_dir / "metrics.csv"
                    # If the logger created version_* subdirectories, prefer the most recent one
                    version_dirs = sorted(classic_log_dir.glob("version_*"))
                    if version_dirs:
                        metrics_file = version_dirs[-1] / "metrics.csv"
if metrics_file.is_file():
plot_loss_curve_from_csv(
metrics_csv_path=metrics_file,
output_path=classic_output_dir / "loss_curve.png",
title="Classic Run Training Progression",
                            train_loss_col='train_loss',  # Must match the metric names logged during training
                            val_loss_col='val_loss'
)
else:
logger.warning(f"Classic Run: Could not find metrics.csv at {metrics_file} for loss curve plot.")
except Exception as plot_e:
logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
# --- End Classic Loss Plotting ---
else:
logger.warning("Classic training run did not produce metrics.")
except Exception as e:
logger.error(f"Error during classic training run: {e}", exc_info=True)
else:
logger.info("Skipping Classic training run as per config.")
# --- Final Logging Summary ---
logger.info("--- Final Summary ---")
# Log aggregated CV results if they exist
if 'aggregated_test_metrics' in final_results and final_results['aggregated_test_metrics']:
logger.info("--- Aggregated Cross-Validation Test Results ---")
for metric, stats in final_results['aggregated_test_metrics'].items():
logger.info(f"{metric}: {stats.get('mean', np.nan):.4f} ± {stats.get('std', np.nan):.4f}")
elif config.run_cross_validation:
logger.warning("Cross-validation was run, but no metrics were aggregated.")
# Log aggregated ensemble results if they exist
if 'ensemble_results' in final_results and final_results['ensemble_results']:
logger.info("--- Aggregated Ensemble Test Results (Mean over Test Folds) ---")
agg_ensemble = {}
for fold_res in final_results['ensemble_results'].values():
if isinstance(fold_res, dict):
for method, metrics in fold_res.items():
if method not in agg_ensemble: agg_ensemble[method] = {}
if isinstance(metrics, dict):
for m_name, m_val in metrics.items():
if m_name not in agg_ensemble[method]: agg_ensemble[method][m_name] = []
agg_ensemble[method][m_name].append(m_val)
else: logger.warning(f"Skipping non-dict metrics for ensemble method '{method}'.")
else: logger.warning("Skipping non-dict fold result in ensemble aggregation.")
for method, metrics_data in agg_ensemble.items():
logger.info(f" Ensemble Method: {method}")
for m_name, values in metrics_data.items():
valid_vals = [v for v in values if v is not None and not np.isnan(v)]
if valid_vals: logger.info(f" {m_name}: {np.mean(valid_vals):.4f} ± {np.std(valid_vals):.4f}")
else: logger.info(f" {m_name}: N/A")
# Log classic results if they exist
if 'classic_training_results' in final_results and final_results['classic_training_results']:
logger.info("--- Classic Training Test Results ---")
classic_res = final_results['classic_training_results']
for metric, value in classic_res.items():
logger.info(f"{metric}: {value:.4f}")
logger.info("-------------------------------------------------")
end_time = time.perf_counter()
logger.info(f"Training pipeline finished successfully in {end_time - start_time:.2f} seconds.")
# --- Main Execution ---
def run():
"""Main execution function."""
args = parse_arguments()
config_path = Path(args.config)
# --- Configuration Loading ---
try:
config = load_config(config_path, MainConfig)
except Exception:
# Error already logged in load_config
sys.exit(1)
# --- Setup based on Config ---
# 1. Set Log Level
log_level_name = config.log_level.upper()
log_level = getattr(logging, log_level_name, logging.INFO)
logger.setLevel(log_level)
logger.info(f"Log level set to: {log_level_name}")
if log_level == logging.DEBUG:
logger.debug("# --- Debug mode enabled via config. --- #")
# 2. Set Seed
set_seeds(config.random_seed)
# 3. Determine Output Directory
output_dir = Path(config.output_dir)
# --- Pipeline Execution ---
try:
run_training_pipeline(config, output_dir)
except SystemExit as e:
logger.warning(f"Pipeline exited with code {e.code}.")
sys.exit(e.code)
except Exception as e:
        logger.critical(f"A critical error occurred during pipeline execution: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
    raise DeprecationWarning(
        "This was the initial training entry point and is no longer maintained. Exiting."
    )