Files
entrix_case_challange/optuna_run.py
2025-05-02 14:36:19 +02:00

395 lines
17 KiB
Python

import argparse
import logging
import sys
import copy # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# Import the Optuna callback for pruning
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
# Import necessary components from the forecasting_model package
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
prepare_fold_data_and_loaders
)
from forecasting_model.model import LSTMForecastLightningModule
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
# from forecasting_model.evaluation import ...
from typing import Dict, List, Any, Optional
# Reuse helper functions from forecasting_model_run.py
# (consider moving them into a shared utils module).
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
root_logger = logging.getLogger()
logger = logging.getLogger(__name__) # Logger for this script
optuna_lg = logging.getLogger('optuna') # Optuna's logger
# --- Argument Parsing ---
def parse_arguments():
    """Parse the command-line options that control the Optuna HPO run.

    Returns:
        argparse.Namespace with the attributes ``config``, ``output_dir``,
        ``study_name``, ``n_trials``, ``storage_db``, ``metric_to_optimize``,
        ``direction``, ``pruning``, ``seed`` and ``debug``.
    """
    cli = argparse.ArgumentParser(
        description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, keyword options) pairs, registered in this order so --help output is stable.
    option_table = [
        (('-c', '--config'), {
            'type': str,
            'default': 'config.yaml',
            'help': "Path to the BASE YAML configuration file.",
        }),
        (('--output-dir',), {
            'type': str,
            'default': 'output/hpo_results',
            'help': "Directory for saving Optuna study database and potentially best trial info.",
        }),
        (('--study-name',), {
            'type': str,
            'default': 'lstm_forecasting_hpo',
            'help': "Name for the Optuna study.",
        }),
        (('--n-trials',), {
            'type': int,
            'default': 20,
            'help': "Number of Optuna trials to run.",
        }),
        (('--storage-db',), {
            'type': str,
            'default': None,  # None -> Optuna in-memory storage
            'help': "Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage.",
        }),
        (('--metric-to-optimize',), {
            'type': str,
            'default': 'val_mae_orig_scale',
            'help': "Metric logged during validation to optimize (must match metric name in LightningModule).",
        }),
        (('--direction',), {
            'type': str,
            'default': 'minimize',
            'choices': ['minimize', 'maximize'],
            'help': "Direction for Optuna optimization.",
        }),
        (('--pruning',), {
            'action': 'store_true',
            'help': "Enable Optuna's trial pruning based on intermediate validation results.",
        }),
        (('--seed',), {
            'type': int,
            'default': 42,  # fixed seed for the HPO driver itself
            'help': "Random seed for the main HPO script.",
        }),
        (('--debug',), {
            'action': 'store_true',
            'help': "Override log level to DEBUG.",
        }),
    ]
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
# --- Optuna Objective Function ---
def _worst_score(maximize: bool) -> float:
    """Return the worst attainable score for the optimisation direction (used as a penalty)."""
    return float('-inf') if maximize else float('inf')


def _train_and_score_fold(
    trial: optuna.Trial,
    trial_config: MainConfig,
    df: pd.DataFrame,
    train_idx,
    val_idx,
    test_idx,
    metric_to_optimize: str,
    monitor_mode: str,
    enable_pruning: bool,
) -> Optional[float]:
    """Train one model on a single CV fold and return its validation score.

    Returns:
        The best value of ``metric_to_optimize`` tracked by early stopping when early
        stopping is enabled and recorded a finite value; otherwise the last value the
        trainer logged for that metric; ``None`` if the metric was never logged.

    Raises:
        optuna.TrialPruned: propagated from the pruning callback when enabled.
    """
    # The test loader is not needed for HPO; test_idx is only forwarded to the helper.
    train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
        full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
        target_col=trial_config.data.target_col,
        feature_config=trial_config.features,
        train_config=trial_config.training,
        eval_config=trial_config.evaluation,
    )

    # input_size is only known after feature preparation, so it is injected per fold.
    fold_model_config = trial_config.model.copy(update={
        'input_size': input_size,
        'forecast_horizon': trial_config.features.forecast_horizon,
    })
    model = LSTMForecastLightningModule(
        model_config=fold_model_config,
        train_config=trial_config.training,
        target_scaler=target_scaler,
    )

    callbacks = []
    early_stopping = None
    patience = trial_config.training.early_stopping_patience
    if patience is not None and patience > 0:
        early_stopping = EarlyStopping(
            monitor=metric_to_optimize,
            patience=patience,
            mode=monitor_mode,
            verbose=False,  # less noise during HPO
        )
        callbacks.append(early_stopping)
    if enable_pruning:
        callbacks.append(PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize))

    use_gpu = torch.cuda.is_available()
    trainer = pl.Trainer(
        accelerator='gpu' if use_gpu else 'cpu',
        # devices=1 is valid for both accelerators; devices=None is rejected for the
        # CPU accelerator by recent Lightning releases.
        devices=1,
        max_epochs=trial_config.training.epochs,
        callbacks=callbacks,
        logger=False,                # no PL loggers during HPO
        enable_checkpointing=False,  # no checkpoints during HPO
        enable_progress_bar=False,
        enable_model_summary=False,
        gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
        precision=getattr(trial_config.training, 'precision', 32),
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # callback_metrics only holds the LAST epoch's value; with early stopping enabled
    # that is up to `patience` epochs past the best one, so prefer the tracked best.
    if early_stopping is not None:
        best = float(early_stopping.best_score)
        if np.isfinite(best):
            return best
    last_logged = trainer.callback_metrics.get(metric_to_optimize)
    return None if last_logged is None else float(last_logged)


def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
    output_base_dir: Path,
    metric_to_optimize: str,
    enable_pruning: bool,
) -> float:
    """Optuna objective: cross-validate one hyperparameter configuration.

    Suggests hyperparameters, builds a validated trial config from ``base_config``,
    trains one model per CV fold, and returns the mean of the finite per-fold
    validation scores.

    Args:
        trial: Optuna trial used to suggest hyperparameters (and for pruning).
        base_config: Loaded base configuration; it is deep-copied, never mutated.
        df: Full raw dataset, loaded once by the caller.
        output_base_dir: Base directory for trial artifacts (currently unused).
        metric_to_optimize: Name of the validation metric logged by the model.
        enable_pruning: Attach Optuna's Lightning pruning callback.

    Returns:
        The mean validation metric across folds, or the worst possible value for the
        study direction (+inf when minimising, -inf when maximising) if the
        configuration is invalid or every fold failed.

    Raises:
        optuna.TrialPruned: when the pruner stops the trial.
        Exception: re-raised if the base configuration cannot be copied.
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Take the direction from the study instead of a module-level ``args`` global,
    # which only exists once run_hpo() has run. Trials without a study (e.g.
    # optuna.trial.FixedTrial) fall back to minimisation.
    study = getattr(trial, 'study', None)
    maximize = study is not None and study.direction == optuna.study.StudyDirection.MAXIMIZE
    monitor_mode = 'max' if maximize else 'min'
    worst_score = _worst_score(maximize)

    # --- 1. Suggest Hyperparameters (on a private copy of the base config) ---
    try:
        trial_config_dict = copy.deepcopy(base_config.dict())
    except Exception as e:
        logger.error(f"Failed to deep copy base configuration: {e}")
        raise  # cannot proceed without a config
    # Keep parameter names/ranges stable when resuming a stored study, so that past
    # trials remain comparable.
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)

    # --- 2. Re-validate the trial config ---
    try:
        trial_config = MainConfig(**trial_config_dict)
        logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
        return worst_score

    # --- 3. Cross-validation ---
    cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
    fold_scores: List[float] = []
    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
        fold_id = fold_num + 1
        logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
        fold_start_time = time.perf_counter()
        try:
            fold_score = _train_and_score_fold(
                trial, trial_config, df, train_idx, val_idx, test_idx,
                metric_to_optimize=metric_to_optimize,
                monitor_mode=monitor_mode,
                # Report intermediate values from the first fold only: every fold restarts
                # at epoch 0, so later folds would re-report already-used steps (ignored
                # by Optuna) or mix their values into the first fold's learning curve.
                enable_pruning=enable_pruning and fold_num == 0,
            )
        except optuna.TrialPruned:
            raise  # let Optuna record the trial as pruned
        except Exception as e:
            logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
            fold_scores.append(worst_score)
            continue

        if fold_score is None:
            logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using the worst possible value.")
            fold_score = worst_score
        logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {fold_score:.4f}")
        fold_scores.append(fold_score)
        logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")

    # --- 4. Aggregate across folds ---
    if not fold_scores:
        logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
        return worst_score
    valid_scores = [s for s in fold_scores if np.isfinite(s)]
    if not valid_scores:
        logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
        return worst_score
    n_failed = len(fold_scores) - len(valid_scores)
    if n_failed:
        logger.warning(
            f"Trial {trial.number}: {n_failed}/{len(fold_scores)} fold(s) failed; "
            f"averaging over the remaining {len(valid_scores)} fold(s) only."
        )
    average_val_metric = float(np.mean(valid_scores))

    logger.info(f"--- Trial {trial.number}: Finished ---")
    logger.info(f"  Average validation {metric_to_optimize}: {average_val_metric:.5f}")
    logger.info(f"  Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
    # --- 5. Return Metric for Optuna ---
    return average_val_metric
# --- Main HPO Execution ---
def _resolve_storage_url(storage_url: Optional[str]) -> Optional[str]:
    """Normalise an Optuna storage URL, creating the directory for SQLite files.

    Args:
        storage_url: Storage URL from the command line, or None/empty for in-memory.

    Returns:
        None when no storage was given; for a file-backed ``sqlite:///`` URL, the same
        URL rewritten with an absolute database path (its parent directory is created);
        any other URL (e.g. PostgreSQL/MySQL, SQLite ``:memory:``) unchanged.
    """
    if not storage_url:
        return None
    sqlite_prefix = "sqlite:///"
    if not storage_url.startswith(sqlite_prefix):
        # Non-SQLite backends have no local database file to prepare.
        return storage_url
    db_location = storage_url[len(sqlite_prefix):]
    if not db_location or db_location == ":memory:":
        return storage_url
    db_path = Path(db_location)
    db_path.parent.mkdir(parents=True, exist_ok=True)
    return f"{sqlite_prefix}{db_path.resolve()}"


def run_hpo():
    """Run the Optuna hyperparameter search end to end.

    Parses CLI arguments, loads the base config and the dataset once, creates (or
    resumes) an Optuna study, optimises ``objective`` for ``--n-trials`` trials, then
    logs the best trial and saves its hyperparameters as JSON in ``--output-dir``.

    Exits the process with status 1 if the config or data cannot be loaded, if the
    study raises, or if no trial completed successfully.
    """
    # Kept as a module-level global for code that reads the parsed arguments from
    # module scope.
    global args
    import json  # only needed to export the best hyperparameters

    args = parse_arguments()
    config_path = Path(args.config)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Log levels ---
    if args.debug:
        root_logger.setLevel(logging.DEBUG)
        optuna_lg.setLevel(logging.DEBUG)
        pl_logger.setLevel(logging.DEBUG)
        logger.debug("Debug mode enabled.")
    else:
        # Reduce verbosity during HPO runs; keep PL start/end messages.
        optuna_lg.setLevel(logging.WARNING)
        pl_logger.setLevel(logging.INFO)

    # --- Configuration Loading ---
    try:
        base_config = load_config(config_path)
    except Exception:
        logger.critical(f"Failed to load configuration from '{config_path}'.", exc_info=True)
        sys.exit(1)

    # --- Seed Setting (for the HPO driver itself) ---
    set_seeds(args.seed)

    # --- Load Data Once (hyperparameters tuned here do not change the raw data) ---
    try:
        logger.info("Loading base dataset...")
        df = load_raw_data(base_config.data)
        logger.info("Base dataset loaded.")
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    storage_url = _resolve_storage_url(args.storage_db)
    if storage_url:
        logger.info(f"Using Optuna storage: {storage_url}")
    else:
        logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")

    try:
        study = optuna.create_study(
            study_name=args.study_name,
            storage=storage_url,
            direction=args.direction,
            load_if_exists=True,  # resume a previous study with the same name/storage
            pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner(),
        )

        # --- Run Optimization ---
        logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
        study.optimize(
            lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
            n_trials=args.n_trials,
            timeout=None,
        )

        # --- Report Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")
        completed_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
        if not completed_trials:
            # study.best_trial would raise ValueError here; report the real cause instead.
            logger.error("No trial completed successfully (all trials failed or were pruned); no best hyperparameters to report.")
            sys.exit(1)

        best_trial = study.best_trial
        logger.info(f"Best trial number: {best_trial.number}")
        logger.info(f"  Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
        logger.info("  Best hyperparameters:")
        for key, value in best_trial.params.items():
            logger.info(f"    {key}: {value}")

        # --- Save Best Hyperparameters ---
        best_params_file = output_dir / f"{args.study_name}_best_params.json"
        try:
            with best_params_file.open('w', encoding='utf-8') as f:
                json.dump(best_trial.params, f, indent=4)
            logger.info(f"Best hyperparameters saved to {best_params_file}")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to save best parameters: {e}")
    except Exception as e:
        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: run the HPO search only when executed directly, not on import.
    run_hpo()