import argparse
import logging
import sys
import copy  # For deep copying the config
import json
from pathlib import Path
import time

import numpy as np
import pandas as pd
import torch
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

# Import the Optuna callback for pruning
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback

# Import necessary components from the forecasting_model package
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
)
from forecasting_model.model import LSTMForecastLightningModule
# Evaluation functions are not needed here; Optuna optimizes based on validation metrics.
# from forecasting_model.evaluation import ...
from typing import Dict, List, Any, Optional

# Import helper functions from forecasting_model_run.py (or move them to a shared utils file).
from forecasting_model_run import load_config, set_seeds  # Assuming these are accessible

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.INFO)  # Keep PL logs; may be raised to WARNING later

# --- Basic Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
root_logger = logging.getLogger()
logger = logging.getLogger(__name__)  # Logger for this script
optuna_lg = logging.getLogger('optuna')  # Optuna's logger


# --- Argument Parsing ---
def parse_arguments():
    """Parses command-line arguments for Optuna HPO."""
    parser = argparse.ArgumentParser(
        description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '-c', '--config', type=str, default='config.yaml',
        help="Path to the BASE YAML configuration file."
    )
    parser.add_argument(
        '--output-dir', type=str, default='output/hpo_results',
        help="Directory for saving the Optuna study database and best trial info."
    )
    parser.add_argument(
        '--study-name', type=str, default='lstm_forecasting_hpo',
        help="Name for the Optuna study."
    )
    parser.add_argument(
        '--n-trials', type=int, default=20,
        help="Number of Optuna trials to run."
    )
    parser.add_argument(
        '--storage-db', type=str, default=None,  # Defaults to in-memory if not specified
        help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). "
             "If None, uses in-memory storage."
    )
    parser.add_argument(
        '--metric-to-optimize', type=str, default='val_mae_orig_scale',
        help="Validation metric to optimize (must match a metric name logged by the LightningModule)."
    )
    parser.add_argument(
        '--direction', type=str, default='minimize', choices=['minimize', 'maximize'],
        help="Direction for Optuna optimization."
    )
    parser.add_argument(
        '--pruning', action='store_true',
        help="Enable Optuna's trial pruning based on intermediate validation results."
    )
    parser.add_argument(
        '--seed', type=int, default=42,  # Fixed seed for the HPO process itself
        help="Random seed for the main HPO script."
    )
    parser.add_argument(
        '--debug', action='store_true',
        help="Override log level to DEBUG."
    )
    args = parser.parse_args()
    return args
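
# Example invocation (illustrative only; the script filename is a placeholder, the flags
# are the ones defined in parse_arguments above, and the paths/study name are examples):
#
#   python hpo_run.py \
#       --config config.yaml \
#       --output-dir output/hpo_results \
#       --study-name lstm_forecasting_hpo \
#       --n-trials 50 \
#       --storage-db sqlite:///output/hpo_results/study.db \
#       --metric-to-optimize val_mae_orig_scale \
#       --direction minimize \
#       --pruning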

# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,   # The loaded base config
    df: pd.DataFrame,          # The loaded data
    output_base_dir: Path,     # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
    metric_to_optimize: str,
    enable_pruning: bool,
) -> float:
    """
    Optuna objective function.

    Trains and evaluates one set of hyperparameters using cross-validation
    and returns the average validation metric.
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # --- 1. Suggest Hyperparameters ---
    # Make a deep copy of the base config to modify for this trial.
    # Converting to a dict and back is easier than Pydantic's copy for deep nested updates.
    try:
        trial_config_dict = copy.deepcopy(base_config.dict())  # Convert to dict for easier modification
    except Exception as e:
        logger.error(f"Failed to deep copy base configuration: {e}")
        raise  # Cannot proceed without a config

    # Suggest values for the hyperparameters we want to tune.
    # Example suggestions (adjust ranges and types as needed):
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
    # Example: Suggest sequence length? (Requires careful handling as it affects data prep.)
    # trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)

    # Note: `args` is the module-level namespace set in run_hpo(); it is used here for the
    # optimization direction. Define the worst possible value up front so it is available
    # both for invalid configs and for failed folds.
    monitor_mode = "min" if args.direction == "minimize" else "max"
    worst_value = float('inf') if monitor_mode == 'min' else float('-inf')

    # --- 2. Re-validate Trial Config (Optional but Recommended) ---
    try:
        trial_config = MainConfig(**trial_config_dict)
        logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
        # Return the worst value to penalize invalid configs.
        return worst_value

    # --- 3. Run Cross-Validation for this Trial ---
    cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
    fold_best_val_metrics: List[Optional[float]] = []

    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
        fold_id = fold_num + 1
        logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
        fold_start_time = time.perf_counter()

        # Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
        # fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
        # fold_trial_dir.mkdir(parents=True, exist_ok=True)

        try:
            # --- Per-Fold Data Prep ---
            # Use trial_config for batch sizes etc.
            train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
                full_df=df,
                train_idx=train_idx,
                val_idx=val_idx,
                test_idx=test_idx,  # Test loader not needed here
                target_col=trial_config.data.target_col,
                feature_config=trial_config.features,
                train_config=trial_config.training,
                eval_config=trial_config.evaluation  # Pass eval for batch size if needed by prep?
            )

            # --- Model Instantiation ---
            current_model_config = trial_config.model.copy(
                update={'input_size': input_size,
                        'forecast_horizon': trial_config.features.forecast_horizon}
            )
            model = LSTMForecastLightningModule(
                model_config=current_model_config,
                train_config=trial_config.training,
                target_scaler=target_scaler
            )

            # --- Callbacks for this Trial/Fold ---
            # Monitor the metric Optuna cares about.
            callbacks = []
            if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
                early_stopping = EarlyStopping(
                    monitor=metric_to_optimize,
                    patience=trial_config.training.early_stopping_patience,
                    mode=monitor_mode,
                    verbose=False  # Less verbose during HPO
                )
                callbacks.append(early_stopping)

            # Add the Optuna pruning callback
            if enable_pruning:
                pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
                callbacks.append(pruning_callback)

            # Optional: LR monitor
            # callbacks.append(LearningRateMonitor(logging_interval='epoch'))

            # --- Trainer for this Trial/Fold ---
            trainer = pl.Trainer(
                accelerator='gpu' if torch.cuda.is_available() else 'cpu',
                devices=1 if torch.cuda.is_available() else None,
                max_epochs=trial_config.training.epochs,
                callbacks=callbacks,
                logger=False,                # Disable default PL logging during HPO
                enable_checkpointing=False,  # Disable checkpoint saving during HPO
                enable_progress_bar=False,   # Disable the progress bar for cleaner logs
                enable_model_summary=False,  # Disable the model summary
                gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
                precision=getattr(trial_config.training, 'precision', 32),
                # Log GPU usage if available?
                # log_gpu_memory='min_max',
            )

            # --- Fit the Model ---
            logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

            # --- Get Best Validation Score for Pruning/Reporting ---
            # Access the monitored metric value from the trainer's logged metrics or callback state.
            # Ensure the key matches exactly what is logged in validation_step.
            best_val_score = trainer.callback_metrics.get(metric_to_optimize)
            if best_val_score is None:
                logger.warning(
                    f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found "
                    f"in trainer metrics. Using inf/nan."
                )
                # Handle cases where training might have failed or the metric wasn't logged.
                best_val_score = worst_value  # Worst possible value
            else:
                best_val_score = best_val_score.item()  # Convert tensor to float

            logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")
            fold_best_val_metrics.append(best_val_score)

            # --- Intermediate Pruning Report (Optional but Recommended) ---
            # Report the intermediate value (best score for this fold) to Optuna
            # trial.report(best_val_score, fold_id)  # Report score at step `fold_id`
            # Check if the trial should be pruned based on reported values
            # if trial.should_prune():
            #     logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
            #     raise optuna.TrialPruned()

            logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")

        except optuna.TrialPruned:
            # Re-raise the prune exception to let Optuna handle it.
            raise
        except Exception as e:
            logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
            # Record a failure for this fold (worst value).
            fold_best_val_metrics.append(worst_value)
            # Optionally: break the CV loop for this trial if one fold fails catastrophically?
            # break

    # --- 4. Calculate Average Metric Across Folds ---
    if not fold_best_val_metrics:
        logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
        return worst_value

    # Handle potential infinities or NaNs from failed folds.
    valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
    if not valid_scores:
        logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
        return worst_value

    average_val_metric = np.mean(valid_scores)
    logger.info(f"--- Trial {trial.number}: Finished ---")
    logger.info(f"  Average validation {metric_to_optimize}: {average_val_metric:.5f}")
    logger.info(f"  Total trial time: {time.perf_counter() - trial_start_time:.2f}s")

    # --- 5. Return Metric for Optuna ---
    return average_val_metric

# --- Main HPO Execution ---
def run_hpo():
    """Main execution function for HPO."""
    global args  # Make args accessible in objective (simplifies passing) - or use functools.partial
    args = parse_arguments()

    config_path = Path(args.config)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)  # Ensure the output dir exists

    # Adjust log levels if the debug flag is set
    if args.debug:
        root_logger.setLevel(logging.DEBUG)
        optuna_lg.setLevel(logging.DEBUG)
        pl_logger.setLevel(logging.DEBUG)
        logger.debug("Debug mode enabled.")
    else:
        # Reduce verbosity during HPO runs
        optuna_lg.setLevel(logging.WARNING)
        pl_logger.setLevel(logging.INFO)  # Keep INFO for PL start/end messages

    # --- Configuration Loading ---
    try:
        base_config = load_config(config_path)
    except Exception:
        sys.exit(1)

    # --- Seed Setting (for the HPO script itself) ---
    set_seeds(args.seed)

    # --- Load Data Once ---
    # Assume the data doesn't change based on HPs (unless sequence_length is tuned heavily).
    try:
        logger.info("Loading base dataset...")
        df = load_raw_data(base_config.data)
        logger.info("Base dataset loaded.")
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    storage_path = args.storage_db
    if storage_path:
        # Ensure the directory exists if using SQLite file storage
        db_path = Path(storage_path.replace("sqlite:///", ""))
        db_path.parent.mkdir(parents=True, exist_ok=True)
        storage_path = f"sqlite:///{db_path.resolve()}"  # Use an absolute path
        logger.info(f"Using Optuna storage: {storage_path}")
    else:
        logger.warning("No Optuna storage DB specified; using in-memory storage (results lost on exit).")

    try:
        # Create or load the study
        study = optuna.create_study(
            study_name=args.study_name,
            storage=storage_path,
            direction=args.direction,
            load_if_exists=True,  # Load previous results if the study exists
            pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()  # Example pruner
        )

        # --- Run Optimization ---
        logger.info(
            f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, "
            f"metric='{args.metric_to_optimize}', direction='{args.direction}'"
        )
        study.optimize(
            lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
            n_trials=args.n_trials,
            timeout=None  # Optional: set a timeout in seconds
            # Optional: add callbacks (e.g., a logging callback)
        )

        # --- Report Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")
        best_trial = study.best_trial
        logger.info(f"Best trial number: {best_trial.number}")
        logger.info(f"  Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
        logger.info("  Best hyperparameters:")
        for key, value in best_trial.params.items():
            logger.info(f"    {key}: {value}")

        # --- Save Best Hyperparameters (Optional) ---
        best_params_file = output_dir / f"{args.study_name}_best_params.json"
        try:
            with open(best_params_file, 'w') as f:
                json.dump(best_trial.params, f, indent=4)
            logger.info(f"Best hyperparameters saved to {best_params_file}")
        except Exception as e:
            logger.error(f"Failed to save best parameters: {e}")

    except Exception as e:
        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_hpo()
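
# Inspecting results later: if a SQLite storage URL was used, the finished study can be
# re-loaded in a separate session for analysis. A minimal sketch (study name and path are
# the defaults used above; adjust to your run):
#
#   import optuna
#   study = optuna.load_study(
#       study_name="lstm_forecasting_hpo",
#       storage="sqlite:///output/hpo_results/study.db",
#   )
#   print(study.best_trial.params)
#   trials_df = study.trials_dataframe()  # One row per trial, including states and values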