""" Classic training routine: Train on initial data segment, validate and test on final segments. """ import logging import time from pathlib import Path import pandas as pd import torch import yaml import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor from pytorch_lightning.loggers import CSVLogger from typing import Dict, Optional from forecasting_model.utils.forecast_config_model import MainConfig from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic from forecasting_model.train.model import LSTMForecastLightningModule from forecasting_model.evaluation import evaluate_fold_predictions from forecasting_model.utils.helper import save_results from forecasting_model.io.plotting import plot_loss_curve_from_csv logger = logging.getLogger(__name__) def run_classic_training( config: MainConfig, full_df: pd.DataFrame, output_base_dir: Path ) -> Optional[Dict[str, float]]: """ Runs a single training pipeline using a classic train/val/test split. Args: config: The main configuration object. full_df: The complete raw DataFrame. output_base_dir: The base directory where general outputs are saved. Classic results will be saved in a subdirectory. Returns: A dictionary containing test metrics (e.g., {'MAE': ..., 'RMSE': ...}) for the classic run, or None if it fails. """ run_start_time = time.perf_counter() logger.info("--- Starting Classic Training Run ---") # Define a specific output directory for this run classic_output_dir = output_base_dir / "classic_run" classic_output_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Classic run outputs will be saved to: {classic_output_dir}") test_metrics: Optional[Dict[str, float]] = None best_val_score: Optional[float] = None best_model_path: Optional[str] = None try: # --- Data Splitting --- logger.info("Splitting data into classic train/val/test sets...") n_samples = len(full_df) val_frac = config.cross_validation.val_size_fraction test_frac = config.cross_validation.test_size_fraction train_idx, val_idx, test_idx = split_data_classic(n_samples, val_frac, test_frac) # Store test datetime index for evaluation plotting test_datetime_index = full_df.iloc[test_idx].index # --- Data Preparation --- logger.info("Preparing data loaders for the classic split...") train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders( full_df=full_df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, target_col=config.data.target_col, feature_config=config.features, train_config=config.training, eval_config=config.evaluation ) logger.info(f"Data loaders prepared. 
        # Save artifacts specific to this run if needed (e.g., for later inference)
        torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
        torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
        torch.save(input_size, classic_output_dir / "classic_input_size.pt")

        # Save config for this run
        try:
            config_dump = config.model_dump()  # Pydantic v2
        except AttributeError:
            config_dump = config.dict()  # Fallback for Pydantic v1
        with open(classic_output_dir / "config.yaml", 'w') as f:
            yaml.dump(config_dump, f, default_flow_style=False)

        # --- Model Initialization ---
        model = LSTMForecastLightningModule(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler
        )
        logger.info("Classic LSTMForecastLightningModule initialized.")

        # --- PyTorch Lightning Callbacks ---
        monitor_metric = "val_MeanAbsoluteError"  # Monitor same metric as CV folds
        monitor_mode = "min"

        early_stop_callback = None
        if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
            early_stop_callback = EarlyStopping(
                monitor=monitor_metric,
                min_delta=0.0001,
                patience=config.training.early_stopping_patience,
                verbose=True,
                mode=monitor_mode
            )
            logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

        checkpoint_callback = ModelCheckpoint(
            dirpath=classic_output_dir / "checkpoints",
            filename="best_classic_model",  # Simple filename
            save_top_k=1,
            monitor=monitor_metric,
            mode=monitor_mode,
            verbose=True
        )
        logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        callbacks = [checkpoint_callback, lr_monitor]
        if early_stop_callback:
            callbacks.append(early_stop_callback)

        # --- PyTorch Lightning Logger ---
        pl_logger = CSVLogger(save_dir=str(classic_output_dir), name="training_logs")
        logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")

        # --- PyTorch Lightning Trainer ---
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        devices = 1 if accelerator == 'gpu' else None
        precision = getattr(config.training, 'precision', 32)

        trainer = pl.Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=config.training.epochs,
            callbacks=callbacks,
            logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader) // 10),
            enable_progress_bar=True,
            gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
            precision=precision,
        )
        logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

        # --- Training ---
        logger.info("Starting classic model training...")
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        logger.info("Classic model training finished.")

        # Store best validation score and path
        best_val_score_tensor = trainer.checkpoint_callback.best_model_score
        best_model_path = trainer.checkpoint_callback.best_model_path
        best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None

        if best_val_score is not None:
            logger.info(f"Best validation score ({monitor_metric}): {best_val_score:.4f}")
            logger.info(f"Best model checkpoint path: {best_model_path}")
        else:
            logger.warning(f"Could not retrieve best validation score/path (metric: {monitor_metric}). Evaluation might use last model.")
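            # Assumption: no best checkpoint was recorded (e.g. training ended before the
            # first validation pass), so fall back to the last checkpoint for prediction below.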
Evaluation might use last model.") best_model_path = None # --- Prediction on Test Set --- logger.info("Starting prediction on classic test set using best checkpoint...") prediction_results_list = trainer.predict( ckpt_path=best_model_path if best_model_path else 'last', dataloaders=test_loader ) # --- Evaluation --- if not prediction_results_list: logger.error("Predict phase did not return any results for classic run.") test_metrics = None else: try: # Shapes: (n_samples, len(horizons)) all_preds_scaled = torch.cat([b['preds_scaled'] for b in prediction_results_list], dim=0).numpy() n_predictions = len(all_preds_scaled) # Number of samples actually predicted if 'targets_scaled' in prediction_results_list[0]: all_targets_scaled = torch.cat([b['targets_scaled'] for b in prediction_results_list], dim=0).numpy() if len(all_targets_scaled) != n_predictions: logger.error(f"Classic Run: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).") raise ValueError("Prediction and target count mismatch during classic evaluation.") else: raise ValueError("Targets missing from prediction results.") logger.info(f"Processing {n_predictions} prediction results for classic test set...") # --- Calculate Correct Time Index for Plotting (First Horizon) --- target_time_index_for_plotting = None if test_idx is not None and config.features.forecast_horizon: try: test_block_index = full_df.index[test_idx] # Use the test_idx from classic split seq_len = config.features.sequence_length first_horizon = config.features.forecast_horizon[0] start_offset = seq_len + first_horizon - 1 if start_offset < len(test_block_index): end_index = min(start_offset + n_predictions, len(test_block_index)) target_time_index_for_plotting = test_block_index[start_offset:end_index] if len(target_time_index_for_plotting) != n_predictions: logger.warning(f"Classic Run: Calculated target time index length ({len(target_time_index_for_plotting)}) " f"does not match prediction count ({n_predictions}). 
Plotting x-axis might be misaligned.") target_time_index_for_plotting = None else: logger.warning(f"Classic Run: Cannot calculate target time index, start offset ({start_offset}) " f"exceeds test block length ({len(test_block_index)}).") except Exception as e: logger.error(f"Classic Run: Error calculating target time index for plotting: {e}", exc_info=True) target_time_index_for_plotting = None # Ensure it's None if error occurs else: logger.warning(f"Classic Run: Skipping target time index calculation (missing test_idx or forecast_horizon).") # --- End Index Calculation --- # Use the classic run specific objects and config test_metrics = evaluate_fold_predictions( y_true_scaled=all_targets_scaled, y_pred_scaled=all_preds_scaled, target_scaler=target_scaler, eval_config=config.evaluation, fold_num=-1, # Indicate classic run output_dir=str(classic_output_dir), plot_subdir="plots", prediction_time_index=target_time_index_for_plotting, # Pass the correctly calculated index forecast_horizons=config.features.forecast_horizon, plot_title_prefix="Classic Run" ) # Save metrics save_results({"overall_metrics": test_metrics}, classic_output_dir / "test_metrics.json") logger.info(f"Classic run test metrics (overall): {test_metrics}") # --- Plot Loss Curve for Classic Run --- try: # Adjusted logic to find metrics.csv inside potential version_*/ directories classic_log_dir = classic_output_dir / "training_logs" metrics_file = None version_dirs = list(classic_log_dir.glob("version_*")) if version_dirs: # Assuming the latest version directory contains the relevant logs latest_version_dir = max(version_dirs, key=lambda p: p.stat().st_mtime) potential_metrics_file = latest_version_dir / "metrics.csv" if potential_metrics_file.is_file(): metrics_file = potential_metrics_file else: logger.warning(f"Classic Run: metrics.csv not found in latest version directory: {latest_version_dir}") else: # Fallback if no version_* directories exist (less common with CSVLogger) potential_metrics_file = classic_log_dir / "metrics.csv" if potential_metrics_file.is_file(): metrics_file = potential_metrics_file if metrics_file and metrics_file.is_file(): plot_loss_curve_from_csv( metrics_csv_path=metrics_file, output_path=classic_output_dir / "loss_curve.png", title="Classic Run Training Progression", train_loss_col='train_loss', # Changed from 'train_loss_epoch' val_loss_col='val_loss' # Keep as 'val_loss' ) logger.info(f"Generating loss curve for classic run from: {metrics_file}") else: logger.warning(f"Classic Run: Could not find metrics.csv in {classic_log_dir} or its version subdirectories for loss curve plot.") except Exception as plot_e: logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True) # --- End Classic Loss Plotting --- except (KeyError, ValueError, Exception) as e: logger.error(f"Error processing classic prediction results: {e}", exc_info=True) test_metrics = None except Exception as e: logger.error(f"An error occurred during the classic training pipeline: {e}", exc_info=True) test_metrics = None # Indicate failure finally: if torch.cuda.is_available(): torch.cuda.empty_cache() run_end_time = time.perf_counter() logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---") return test_metrics