import argparse
import logging
import sys
import warnings  # Import the warnings module
import copy  # For deep copying config
from pathlib import Path
import time

import pandas as pd
import torch
import yaml  # Added for saving best config
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
    prepare_fold_data_and_loaders,
    split_data_classic
)
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.train.classic import run_classic_training

# Import helper functions from forecasting_model_run.py
from forecasting_model.utils.helper import load_config, set_seeds

# Import the data processing functions
from forecasting_model.data_processing import load_raw_data

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
    "ignore",
    message=".*You called `self.log.*logger=True.*no logger configured.*",
    category=UserWarning,
    module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()


# --- Argument Parsing (Simplified) ---
def parse_arguments():
    """Parses command-line arguments for Optuna HPO."""
    parser = argparse.ArgumentParser(
        description="Run Hyperparameter Optimization using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='forecasting_config.yaml',
        help="Path to the YAML configuration file containing HPO settings."
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default=None,
        help="Override output directory specified in the configuration file."
    )
    args = parser.parse_args()
    return args


# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
) -> float:  # Ensure it returns a float
    """
    Optuna single-objective function using a classic train/val/test split.

    Returns:
        Validation score from the classic split (minimize).
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Get HPO settings
    hpo_config = base_config.optuna
    # Use the specific metric name from config for validation monitoring
    validation_metric_monitor = hpo_config.metric_to_optimize
    # The objective will be based on the validation metric
    # Hardcode minimization for the single objective
    monitor_mode = "min"
    worst_value = float('inf')  # Represents a poor validation score

    # --- 1. Suggest Hyperparameters ---
    try:
        # Deep copy the base config dictionary to modify for the trial
        trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))  # Use mode='python' for easier modification
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
        # Return worst value
        return worst_value

    # ----- Suggest Hyperparameters -----
    # (Suggest parameters as before, modifying trial_config_dict)
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
    trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
    trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
    # Note: forecast_horizon is NOT tuned here, taken from base_config
    trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])

    use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
    if not use_configured_lags:
        trial_config_dict['features']['lags'] = []

    use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
    if not use_configured_rolling:
        trial_config_dict['features']['rolling_window_sizes'] = []

    trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
    trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
    trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
    trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
    # ----- End of Hyperparameter Suggestions -----

    # --- 2. Re-validate Trial Config ---
    try:
        # Transfer forecast horizon from base config (not tuned)
        trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
        trial_config = MainConfig(**trial_config_dict)
        logger.info(f"Trial {trial.number} Parameters: {trial.params}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
        return worst_value

    # --- Early check for invalid sequence length / forecast horizon combination ---
    try:
        # Make sure forecast_horizon is a list before checking max()
        if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
            raise ValueError("Trial config has invalid forecast_horizon list.")

        min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
        if min_data_for_sequence > len(df):
            logger.warning(f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
                           f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
                           f"exceeds data length ({len(df)}).")
            # Optuna doesn't directly support skipping, so return the worst value
            return worst_value
    except Exception as e:
        logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
        return worst_value

    # --- 3. Run Classic Train/Test ---
    logger.info(f"Trial {trial.number}: Starting Classic Run...")
    validation_metric_value = worst_value  # Initialize to worst
    # Pre-define objects referenced in the finally-block cleanup so `del` cannot
    # raise NameError if data preparation or model setup fails early
    model_cl = trainer_cl = None
    train_loader_cl = val_loader_cl = test_loader_cl = None
    try:
        n_samples = len(df)
        val_frac = trial_config.cross_validation.val_size_fraction
        test_frac = trial_config.cross_validation.test_size_fraction
        train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)

        # Prepare data for classic split
        train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, input_size_cl = prepare_fold_data_and_loaders(
            full_df=df,
            train_idx=train_idx_cl,
            val_idx=val_idx_cl,
            test_idx=test_idx_cl,
            target_col=trial_config.data.target_col,
            feature_config=trial_config.features,
            train_config=trial_config.training,
            eval_config=trial_config.evaluation
        )

        # Initialize Model
        model_cl = LSTMForecastLightningModule(
            model_config=trial_config.model,
            train_config=trial_config.training,
            input_size=input_size_cl,
            target_scaler=target_scaler_cl
        )

        # Callbacks (EarlyStopping and Pruning)
        callbacks_cl = []
        if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
            callbacks_cl.append(EarlyStopping(
                monitor=validation_metric_monitor,  # Monitor validation metric
                patience=trial_config.training.early_stopping_patience,
                mode=monitor_mode,
                verbose=False
            ))

        # Trainer for classic run
        trainer_cl = pl.Trainer(
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1 if torch.cuda.is_available() else None,
            max_epochs=trial_config.training.epochs,
            callbacks=callbacks_cl,
            logger=False,  # logger=False as per original
            enable_checkpointing=False,
            enable_progress_bar=False,
            enable_model_summary=False,
            gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
            precision=getattr(trial_config.training, 'precision', 32),
        )

        # --- Train Model ---
        logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
        trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)

        # --- Get Best Validation Score ---
        # Check the early stopping callback first, if it exists
        best_score_tensor = None
        if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
            if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
                best_score_tensor = callbacks_cl[0].best_score
            elif callbacks_cl[0].stopped_epoch > 0:  # Early stopping triggered
                logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")

        # If early stopping didn't capture the best score, use the last metrics from the trainer
        if best_score_tensor is None:
            metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
            if metric_val is not None:
                best_score_tensor = metric_val  # Use the last logged value

        if best_score_tensor is None:
            logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
            validation_metric_value = worst_value
        else:
            validation_metric_value = best_score_tensor.item()
            logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")

        # Report intermediate value for pruning (if enabled)
        trial.report(validation_metric_value, trainer_cl.current_epoch)
        if trial.should_prune():
            logger.info(f"Trial {trial.number}: Pruned.")
            raise optuna.TrialPruned()

        # Note: We don't run prediction/evaluation on the test set here,
        # as the objective is based on validation performance.
        # The test set evaluation will be done later for the best trial.
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s") except optuna.TrialPruned: # Propagate prune signal, objective will be set to worst later by Optuna raise except Exception as e: logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True) validation_metric_value = worst_value # Assign worst value if classic run fails finally: # Clean up GPU memory after the run del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl if torch.cuda.is_available(): torch.cuda.empty_cache() # --- 4. Return Objective --- logger.info(f"--- Trial {trial.number}: Finished ---") logger.info(f" Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}") logger.info(f" Total time: {time.perf_counter() - trial_start_time:.2f}s") # Return the single objective (validation metric) return float(validation_metric_value) # --- Main HPO Execution --- def run_hpo(): """Main execution function for HPO.""" args = parse_arguments() config_path = Path(args.config) try: base_config = load_config(config_path) # Load base config once logger.info(f"Successfully loaded configuration from {config_path}") except Exception as e: logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True) sys.exit(1) # Setup output dir... if args.output_dir: hpo_base_output_dir = Path(args.output_dir) # Use specific name for HPO dir logger.info(f"Using HPO output directory from command line: {hpo_base_output_dir}") elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"): hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent logger.info(f"Using HPO output directory from Optuna storage path: {hpo_base_output_dir}") else: # Use output_dir from main config if available, otherwise default main_output_dir = Path(getattr(base_config, 'output_dir', 'output')) hpo_base_output_dir = main_output_dir / 'hpo_results' logger.info(f"Using HPO output directory: {hpo_base_output_dir}") hpo_base_output_dir.mkdir(parents=True, exist_ok=True) # Setup logging... (ensure file handler uses hpo_base_output_dir) try: level_name = base_config.log_level.upper() effective_log_level = logging.getLevelName(level_name) log_file = hpo_base_output_dir / f"{base_config.optuna.study_name}_hpo.log" file_handler = logging.FileHandler(log_file, mode='a') formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') file_handler.setFormatter(formatter) # Add handler only if it's not already added (e.g., if run_hpo is called multiple times) if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers): logger.addHandler(file_handler) logger.setLevel(effective_log_level) logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}") if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.") except (AttributeError, ValueError) as e: logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}") logger.setLevel(logging.INFO) # Setup seeding... try: seed = base_config.random_seed set_seeds(seed) except AttributeError: logger.warning("Config missing 'random_seed'. Using default seed 42.") set_seeds(42) # Load data... try: logger.info("Loading base dataset for HPO...") df = load_raw_data(base_config.data) logger.info(f"Base dataset loaded. 
Shape: {df.shape}") except Exception as e: logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True) sys.exit(1) # --- Optuna Study Setup --- try: hpo_config = base_config.optuna if not hpo_config.enabled: logger.info("Optuna optimization is disabled in the configuration.") sys.exit(0) except AttributeError: logger.critical("Optuna configuration section ('optuna') missing.") sys.exit(1) storage_path = hpo_config.storage if storage_path and storage_path.startswith("sqlite:///"): db_path = Path(storage_path.replace("sqlite:///", "")) if not db_path.is_absolute(): db_path = hpo_base_output_dir / db_path # Relative to HPO output dir db_path.parent.mkdir(parents=True, exist_ok=True) storage_path = f"sqlite:///{db_path.resolve()}" logger.info(f"Using Optuna storage: {storage_path}") elif storage_path: logger.info(f"Using Optuna storage: {storage_path} (non-SQLite)") else: logger.warning("No Optuna storage DB specified in config, using in-memory storage.") try: # Change to single objective 'minimize' study = optuna.create_study( study_name=hpo_config.study_name, storage=storage_path, direction="minimize", # Changed to single direction load_if_exists=True, pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner() ) # Remove multi-objective check/attribute setting # if not study._is_multi_objective: # logger.warning(f"Study '{hpo_config.study_name}' exists but is not multi-objective.") # elif 'objective_names' not in study.user_attrs: # study.set_user_attr('objective_names', objective_names) # --- Run Optimization --- logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}") # Updated log message study.optimize( lambda trial: objective(trial, base_config, df), # Pass base_config n_trials=hpo_config.n_trials, timeout=None, gc_after_trial=True ) # --- Report and Process Best Trial --- logger.info("--- Optuna HPO Finished ---") logger.info(f"Number of finished trials: {len(study.trials)}") # Get the single best trial best_trial = study.best_trial if best_trial is None: logger.warning("Optuna study finished, but no successful trial was completed.") else: logger.info(f"Best trial found (Trial {best_trial.number}):") # Log details for the best trial validation_metric_monitor = base_config.optuna.metric_to_optimize logger.info(f" Objective ({validation_metric_monitor}): {best_trial.value:.5f}") # Use .value for single objective logger.info(f" Hyperparameters:") for key, value in best_trial.params.items(): logger.info(f" {key}: {value}") # --- Re-run and Save Artifacts for the Best Trial --- logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...") trial_output_dir = hpo_base_output_dir / f"best_trial_num{best_trial.number}" # Simplified directory name trial_output_dir.mkdir(parents=True, exist_ok=True) try: # 1. 
                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))

                # Update with the trial's hyperparameters (use best_trial.params)
                for key, value in best_trial.params.items():
                    # Need to place the param in the correct nested config dict
                    if key == 'use_lags':
                        # Mirror the handling in objective(): disabling lags means an empty list
                        if not value:
                            best_config_dict['features']['lags'] = []
                    elif key == 'use_rolling_windows':
                        # Mirror the handling in objective(): disabling rolling windows means an empty list
                        if not value:
                            best_config_dict['features']['rolling_window_sizes'] = []
                    elif key in best_config_dict['training']:
                        best_config_dict['training'][key] = value
                    elif key in best_config_dict['model']:
                        best_config_dict['model'][key] = value
                    elif key in best_config_dict['features']:
                        best_config_dict['features'][key] = value
                    else:
                        # Add more sections if HPO tunes them
                        logger.warning(f"Best parameter '{key}' does not map to a known config section; skipping.")

                # Add the non-tuned forecast_horizon back
                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon

                # Ensure evaluation plots and model saving are enabled for this final run
                best_config_dict['evaluation']['save_plots'] = True
                best_config_dict['training']['save_model'] = True  # Assuming you want to save the best model

                # Validate the final config for this trial
                best_trial_config = MainConfig(**best_config_dict)

                # Save the specific config used for this run
                with open(trial_output_dir / "best_config.yaml", 'w') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False)

                # 2. Run classic training (which saves model & plots)
                logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
                # Pass the specific config and output directory
                run_classic_training(
                    config=best_trial_config,
                    full_df=df,
                    output_base_dir=trial_output_dir  # outputs -> hpo_results/best_trial_num<N>/classic_run/
                )

                logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {trial_output_dir}")

            except Exception as e:
                logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)

            # --- Save Best Hyperparameters ---
            best_params_file = hpo_base_output_dir / f"{hpo_config.study_name}_best_params.json"  # Simplified filename
            try:
                import json
                with open(best_params_file, 'w') as f:
                    json.dump(best_trial.params, f, indent=4)  # Use best_trial.params
                logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
            except Exception as e:
                logger.error(f"Failed to save parameters for best trial: {e}")

    except Exception as e:
        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_hpo()
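

# --- Illustrative usage sketch (not invoked by this script) ---
# Minimal example of how the persisted study could be inspected after an HPO
# run, assuming the same `study_name` and SQLite `storage` URL configured above.
# The helper name `print_best_trial` is hypothetical and shown only as a sketch.
def print_best_trial(study_name: str, storage: str) -> None:
    """Reload a persisted Optuna study and print its best trial."""
    study = optuna.load_study(study_name=study_name, storage=storage)
    best = study.best_trial  # Raises ValueError if no trial has completed
    print(f"Best trial #{best.number}: value={best.value:.5f}")
    for key, value in best.params.items():
        print(f"  {key}: {value}")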