intermediate backup
This commit is contained in:
605
optuna_run.py
605
optuna_run.py
@ -1,35 +1,44 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import warnings # Import the warnings module
|
||||
|
||||
import copy # For deep copying config
|
||||
from pathlib import Path
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml # Added for saving best config
|
||||
|
||||
import optuna
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
|
||||
# Import the Optuna callback for pruning
|
||||
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
|
||||
# Import necessary components from the forecasting_model package
|
||||
from forecasting_model.utils.config_model import MainConfig
|
||||
from forecasting_model.utils.forecast_config_model import MainConfig
|
||||
from forecasting_model.data_processing import (
|
||||
load_raw_data,
|
||||
TimeSeriesCrossValidationSplitter,
|
||||
prepare_fold_data_and_loaders
|
||||
prepare_fold_data_and_loaders,
|
||||
split_data_classic
|
||||
)
|
||||
from forecasting_model.model import LSTMForecastLightningModule
|
||||
# We don't need evaluation functions here, Optuna optimizes based on validation metrics
|
||||
# from forecasting_model.evaluation import ...
|
||||
from typing import Dict, List, Any, Optional
|
||||
from forecasting_model.train.model import LSTMForecastLightningModule
|
||||
from forecasting_model.train.classic import run_classic_training
|
||||
|
||||
# Import helper functions from forecasting_model.py (or move them to a shared utils file)
|
||||
# For now, let's redefine simplified versions or assume they exist in utils
|
||||
from forecasting_model_run import load_config, set_seeds # Assuming these are accessible
|
||||
|
||||
# Import helper functions from forecasting_model_run.py
|
||||
from forecasting_model.utils.helper import load_config, set_seeds
|
||||
|
||||
# Import the data processing functions
|
||||
from forecasting_model.data_processing import load_raw_data
|
||||
|
||||
# --- Suppress specific PL warnings about logger=True with no logger ---
|
||||
# This is expected behavior in optuna_run.py where logger=False is intentional
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*You called `self.log.*logger=True.*no logger configured.*",
|
||||
category=UserWarning,
|
||||
module="pytorch_lightning.core.module"
|
||||
)
|
||||
|
||||
# Silence overly verbose libraries if needed
|
||||
mpl_logger = logging.getLogger('matplotlib')
|
||||
@ -37,358 +46,396 @@ mpl_logger.setLevel(logging.WARNING)
|
||||
pil_logger = logging.getLogger('PIL')
|
||||
pil_logger.setLevel(logging.WARNING)
|
||||
pl_logger = logging.getLogger('pytorch_lightning')
|
||||
pl_logger.setLevel(logging.INFO) # Keep PL logs, but maybe set higher later
|
||||
pl_logger.setLevel(logging.WARNING)
|
||||
|
||||
# --- Basic Logging Setup ---
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S')
|
||||
root_logger = logging.getLogger()
|
||||
logger = logging.getLogger(__name__) # Logger for this script
|
||||
optuna_lg = logging.getLogger('optuna') # Optuna's logger
|
||||
format='%(asctime)s - %(levelname)-7s - %(message)s',
|
||||
datefmt='%H:%M:%S')
|
||||
# Get the root logger
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# --- Argument Parsing ---
|
||||
# --- Argument Parsing (Simplified) ---
|
||||
def parse_arguments():
|
||||
"""Parses command-line arguments for Optuna HPO."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Hyperparameter Optimization using Optuna for Time Series Forecasting.",
|
||||
description="Run Hyperparameter Optimization using Optuna.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
'-c', '--config',
|
||||
type=str,
|
||||
default='config.yaml',
|
||||
help="Path to the BASE YAML configuration file."
|
||||
'-c', '--config', type=str, default='forecasting_config.yaml',
|
||||
help="Path to the YAML configuration file containing HPO settings."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output-dir',
|
||||
type=str,
|
||||
default='output/hpo_results',
|
||||
help="Directory for saving Optuna study database and potentially best trial info."
|
||||
'--output-dir', type=str, default=None,
|
||||
help="Override output directory specified in the configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--study-name',
|
||||
type=str,
|
||||
default='lstm_forecasting_hpo',
|
||||
help="Name for the Optuna study."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--n-trials',
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of Optuna trials to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--storage-db',
|
||||
type=str,
|
||||
default=None, # Default to in-memory if not specified
|
||||
help="Optuna storage database URL (e.g., 'sqlite:///output/hpo_results/study.db'). If None, uses in-memory storage."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--metric-to-optimize',
|
||||
type=str,
|
||||
default='val_mae_orig_scale',
|
||||
help="Metric logged during validation to optimize (must match metric name in LightningModule)."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--direction',
|
||||
type=str,
|
||||
default='minimize',
|
||||
choices=['minimize', 'maximize'],
|
||||
help="Direction for Optuna optimization."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pruning',
|
||||
action='store_true',
|
||||
help="Enable Optuna's trial pruning based on intermediate validation results."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--seed',
|
||||
type=int,
|
||||
default=42, # Fixed seed for the HPO process itself
|
||||
help="Random seed for the main HPO script."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help="Override log level to DEBUG."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
# --- Optuna Objective Function ---
|
||||
def objective(
|
||||
trial: optuna.Trial,
|
||||
base_config: MainConfig, # Pass the loaded base config
|
||||
df: pd.DataFrame, # Pass the loaded data
|
||||
output_base_dir: Path, # Base dir for any potential trial artifacts (usually avoid saving checkpoints here)
|
||||
metric_to_optimize: str,
|
||||
enable_pruning: bool
|
||||
) -> float:
|
||||
base_config: MainConfig,
|
||||
df: pd.DataFrame,
|
||||
) -> float: # Ensure it returns a float
|
||||
"""
|
||||
Optuna objective function. Trains and evaluates one set of hyperparameters
|
||||
using cross-validation and returns the average validation metric.
|
||||
Optuna single-objective function using a classic train/val/test split.
|
||||
|
||||
Returns:
|
||||
- Validation score from the classic split (minimize).
|
||||
"""
|
||||
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
|
||||
trial_start_time = time.perf_counter()
|
||||
|
||||
# Get HPO settings
|
||||
hpo_config = base_config.optuna
|
||||
# Use the specific metric name from config for validation monitoring
|
||||
validation_metric_monitor = hpo_config.metric_to_optimize
|
||||
# The objective will be based on the validation metric
|
||||
# Hardcode minimization for the single objective
|
||||
monitor_mode = "min"
|
||||
worst_value = float('inf') # Represents a poor validation score
|
||||
|
||||
# --- 1. Suggest Hyperparameters ---
|
||||
# Make a deep copy of the base config to modify for this trial
|
||||
# Using dict conversion and back might be easier than Pydantic's copy for deep nested updates
|
||||
try:
|
||||
trial_config_dict = copy.deepcopy(base_config.dict()) # Convert to dict for easier modification
|
||||
# Deep copy the base config dictionary to modify for the trial
|
||||
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python')) # Use mode='python' for easier modification
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deep copy base configuration: {e}")
|
||||
raise # Cannot proceed without config
|
||||
logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
|
||||
# Return worst value
|
||||
return worst_value
|
||||
|
||||
# Suggest values for hyperparameters we want to tune
|
||||
# Example suggestions (adjust ranges and types as needed):
|
||||
# ----- Suggest Hyperparameters -----
|
||||
# (Suggest parameters as before, modifying trial_config_dict)
|
||||
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
|
||||
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128])
|
||||
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 32, 256, step=32)
|
||||
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 4)
|
||||
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
|
||||
# Example: Suggest sequence length? (Requires careful handling as it affects data prep)
|
||||
# trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=24)
|
||||
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
|
||||
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
|
||||
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
|
||||
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
|
||||
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
|
||||
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
|
||||
# Note: forecast_horizon is NOT tuned here, taken from base_config
|
||||
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
|
||||
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
|
||||
if not use_configured_lags: trial_config_dict['features']['lags'] = []
|
||||
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
|
||||
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
|
||||
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
|
||||
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
|
||||
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
|
||||
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
|
||||
# ----- End of Hyperparameter Suggestions -----
|
||||
|
||||
# --- 2. Re-validate Trial Config (Optional but Recommended) ---
|
||||
# --- 2. Re-validate Trial Config ---
|
||||
try:
|
||||
# Transfer forecast horizon from base config (not tuned)
|
||||
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
|
||||
trial_config = MainConfig(**trial_config_dict)
|
||||
logger.debug(f"Trial {trial.number} Config: {trial_config.training} {trial_config.model} {trial_config.features}")
|
||||
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
|
||||
except Exception as e:
|
||||
logger.error(f"Trial {trial.number}: Invalid configuration generated from suggested parameters: {e}")
|
||||
# Return a high value (for minimization) to penalize invalid configs
|
||||
return float('inf')
|
||||
logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
|
||||
return worst_value
|
||||
|
||||
# --- Early check for invalid sequence length / forecast horizon combination ---
|
||||
try:
|
||||
# Make sure forecast_horizon is a list before checking max()
|
||||
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
|
||||
raise ValueError("Trial config has invalid forecast_horizon list.")
|
||||
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
|
||||
if min_data_for_sequence > len(df):
|
||||
logger.warning(f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
|
||||
f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
|
||||
f"exceeds data length ({len(df)}).")
|
||||
# Optuna doesn't directly support skipping, so return worst values
|
||||
return worst_value
|
||||
except Exception as e:
|
||||
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
|
||||
return worst_value
|
||||
|
||||
# --- 3. Run Classic Train/Test ---
|
||||
logger.info(f"Trial {trial.number}: Starting Classic Run...")
|
||||
validation_metric_value = worst_value # Initialize to worst
|
||||
try:
|
||||
n_samples = len(df)
|
||||
val_frac = trial_config.cross_validation.val_size_fraction
|
||||
test_frac = trial_config.cross_validation.test_size_fraction
|
||||
train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)
|
||||
|
||||
|
||||
# --- 3. Run Cross-Validation for this Trial ---
|
||||
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
|
||||
fold_best_val_metrics: List[Optional[float]] = []
|
||||
# Prepare data for classic split
|
||||
train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, input_size_cl = prepare_fold_data_and_loaders(
|
||||
full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
|
||||
target_col=trial_config.data.target_col, feature_config=trial_config.features,
|
||||
train_config=trial_config.training, eval_config=trial_config.evaluation
|
||||
)
|
||||
|
||||
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
|
||||
fold_id = fold_num + 1
|
||||
logger.info(f"Trial {trial.number}, Fold {fold_id}: Starting fold evaluation.")
|
||||
fold_start_time = time.perf_counter()
|
||||
# Initialize Model
|
||||
model_cl = LSTMForecastLightningModule(
|
||||
model_config=trial_config.model, train_config=trial_config.training,
|
||||
input_size=input_size_cl, target_scaler=target_scaler_cl
|
||||
)
|
||||
|
||||
# Create a temporary directory for this specific trial+fold if needed (usually avoid for HPO)
|
||||
# fold_trial_dir = output_base_dir / f"trial_{trial.number}" / f"fold_{fold_id:02d}"
|
||||
# fold_trial_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Callbacks (EarlyStopping and Pruning)
|
||||
callbacks_cl = []
|
||||
if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
|
||||
callbacks_cl.append(EarlyStopping(
|
||||
monitor=validation_metric_monitor, # Monitor validation metric
|
||||
patience=trial_config.training.early_stopping_patience,
|
||||
mode=monitor_mode, verbose=False
|
||||
))
|
||||
|
||||
try:
|
||||
# --- Per-Fold Data Prep ---
|
||||
# Use trial_config for batch sizes etc.
|
||||
train_loader, val_loader, _, target_scaler, input_size = prepare_fold_data_and_loaders(
|
||||
full_df=df, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, # Test loader not needed here
|
||||
target_col=trial_config.data.target_col,
|
||||
feature_config=trial_config.features,
|
||||
train_config=trial_config.training,
|
||||
eval_config=trial_config.evaluation # Pass eval for batch size if needed by prep?
|
||||
)
|
||||
# Trainer for classic run
|
||||
trainer_cl = pl.Trainer(
|
||||
accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1 if torch.cuda.is_available() else None,
|
||||
max_epochs=trial_config.training.epochs, callbacks=callbacks_cl, logger=False, # logger=False as per original
|
||||
enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
|
||||
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
|
||||
precision=getattr(trial_config.training, 'precision', 32),
|
||||
)
|
||||
|
||||
# --- Model Instantiation ---
|
||||
current_model_config = trial_config.model.copy(update={'input_size': input_size,
|
||||
'forecast_horizon': trial_config.features.forecast_horizon})
|
||||
model = LSTMForecastLightningModule(
|
||||
model_config=current_model_config,
|
||||
train_config=trial_config.training,
|
||||
target_scaler=target_scaler
|
||||
)
|
||||
# --- Train Model ---
|
||||
logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
|
||||
trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)
|
||||
|
||||
# --- Callbacks for this Trial/Fold ---
|
||||
# Monitor the metric Optuna cares about
|
||||
monitor_mode = "min" if args.direction == "minimize" else "max"
|
||||
# --- Get Best Validation Score ---
|
||||
# Check early stopping callback first if it exists
|
||||
best_score_tensor = None
|
||||
if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
|
||||
if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
|
||||
best_score_tensor = callbacks_cl[0].best_score
|
||||
elif callbacks_cl[0].stopped_epoch > 0 : # Early stopping triggered
|
||||
logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")
|
||||
|
||||
callbacks = []
|
||||
if trial_config.training.early_stopping_patience is not None and trial_config.training.early_stopping_patience > 0:
|
||||
early_stopping = EarlyStopping(
|
||||
monitor=metric_to_optimize,
|
||||
patience=trial_config.training.early_stopping_patience,
|
||||
mode=monitor_mode,
|
||||
verbose=False # Less verbose during HPO
|
||||
)
|
||||
callbacks.append(early_stopping)
|
||||
# If early stopping didn't capture best score, use last metrics from trainer
|
||||
if best_score_tensor is None:
|
||||
metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
|
||||
if metric_val is not None:
|
||||
best_score_tensor = metric_val # Use the last logged value
|
||||
|
||||
# Add Optuna Pruning Callback
|
||||
if enable_pruning:
|
||||
pruning_callback = PyTorchLightningPruningCallback(trial, monitor=metric_to_optimize)
|
||||
callbacks.append(pruning_callback)
|
||||
if best_score_tensor is None:
|
||||
logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
|
||||
validation_metric_value = worst_value
|
||||
else:
|
||||
validation_metric_value = best_score_tensor.item()
|
||||
logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")
|
||||
|
||||
# Optional: LR Monitor
|
||||
# callbacks.append(LearningRateMonitor(logging_interval='epoch'))
|
||||
# Report intermediate value for pruning (if enabled)
|
||||
trial.report(validation_metric_value, trainer_cl.current_epoch)
|
||||
if trial.should_prune():
|
||||
logger.info(f"Trial {trial.number}: Pruned.")
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
# --- Trainer for this Trial/Fold ---
|
||||
trainer = pl.Trainer(
|
||||
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
|
||||
devices=1 if torch.cuda.is_available() else None,
|
||||
max_epochs=trial_config.training.epochs,
|
||||
callbacks=callbacks,
|
||||
logger=False, # Disable default PL logging during HPO
|
||||
enable_checkpointing=False, # Disable checkpoint saving during HPO
|
||||
enable_progress_bar=False, # Disable progress bar for cleaner logs
|
||||
enable_model_summary=False, # Disable model summary
|
||||
gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
|
||||
precision=getattr(trial_config.training, 'precision', 32),
|
||||
# Log GPU usage if available?
|
||||
# log_gpu_memory='min_max',
|
||||
)
|
||||
# Note: We don't run prediction/evaluation on the test set here,
|
||||
# as the objective is based on validation performance.
|
||||
# The test set evaluation will be done later for the best trial.
|
||||
|
||||
# --- Fit the Model ---
|
||||
logger.info(f"Trial {trial.number}, Fold {fold_id}: Fitting model...")
|
||||
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
||||
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")
|
||||
|
||||
# --- Get Best Validation Score for Pruning/Reporting ---
|
||||
# Access the monitored metric value from the trainer's logged metrics or callback state
|
||||
# Ensure the key matches exactly what's logged in validation_step
|
||||
best_val_score = trainer.callback_metrics.get(metric_to_optimize)
|
||||
|
||||
if best_val_score is None:
|
||||
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Metric '{metric_to_optimize}' not found in trainer metrics. Using inf/nan.")
|
||||
# Handle cases where training might have failed or metric wasn't logged
|
||||
best_val_score = float('inf') if monitor_mode == 'min' else float('-inf') # Return worst possible value
|
||||
else:
|
||||
best_val_score = best_val_score.item() # Convert tensor to float
|
||||
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best validation score ({metric_to_optimize}) = {best_val_score:.4f}")
|
||||
|
||||
fold_best_val_metrics.append(best_val_score)
|
||||
|
||||
# --- Intermediate Pruning Report (Optional but Recommended) ---
|
||||
# Report the intermediate value (best score for this fold) to Optuna
|
||||
# trial.report(best_val_score, fold_id) # Report score at step `fold_id`
|
||||
# Check if the trial should be pruned based on reported values
|
||||
# if trial.should_prune():
|
||||
# logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
|
||||
# raise optuna.TrialPruned()
|
||||
|
||||
logger.info(f"Trial {trial.number}, Fold {fold_id}: Finished in {time.perf_counter() - fold_start_time:.2f}s")
|
||||
|
||||
except optuna.TrialPruned:
|
||||
# Re-raise prune exception to let Optuna handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed with error: {e}", exc_info=True)
|
||||
# Record a failure for this fold (e.g., append NaN or worst value)
|
||||
fold_best_val_metrics.append(float('inf') if monitor_mode == 'min' else float('-inf'))
|
||||
# Optionally: Break the CV loop for this trial if one fold fails catastrophically?
|
||||
# break
|
||||
except optuna.TrialPruned:
|
||||
# Propagate prune signal, objective will be set to worst later by Optuna
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)
|
||||
validation_metric_value = worst_value # Assign worst value if classic run fails
|
||||
finally:
|
||||
# Clean up GPU memory after the run
|
||||
del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl
|
||||
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# --- 4. Calculate Average Metric Across Folds ---
|
||||
if not fold_best_val_metrics:
|
||||
logger.error(f"Trial {trial.number}: No validation results obtained across folds.")
|
||||
return float('inf') # Return worst value
|
||||
|
||||
# Handle potential infinities or NaNs from failed folds
|
||||
valid_scores = [s for s in fold_best_val_metrics if np.isfinite(s)]
|
||||
if not valid_scores:
|
||||
logger.error(f"Trial {trial.number}: All folds failed or produced non-finite scores.")
|
||||
return float('inf')
|
||||
|
||||
average_val_metric = np.mean(valid_scores)
|
||||
# --- 4. Return Objective ---
|
||||
logger.info(f"--- Trial {trial.number}: Finished ---")
|
||||
logger.info(f" Average validation {metric_to_optimize}: {average_val_metric:.5f}")
|
||||
logger.info(f" Total trial time: {time.perf_counter() - trial_start_time:.2f}s")
|
||||
logger.info(f" Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}")
|
||||
logger.info(f" Total time: {time.perf_counter() - trial_start_time:.2f}s")
|
||||
|
||||
# --- 5. Return Metric for Optuna ---
|
||||
return average_val_metric
|
||||
# Return the single objective (validation metric)
|
||||
return float(validation_metric_value)
|
||||
|
||||
|
||||
# --- Main HPO Execution ---
|
||||
def run_hpo():
|
||||
"""Main execution function for HPO."""
|
||||
global args # Make args accessible in objective (simplifies passing) - or use functools.partial
|
||||
args = parse_arguments()
|
||||
config_path = Path(args.config)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True) # Ensure output dir exists
|
||||
|
||||
# Adjust log level if debug flag is set
|
||||
if args.debug:
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
optuna_lg.setLevel(logging.DEBUG)
|
||||
pl_logger.setLevel(logging.DEBUG)
|
||||
logger.debug("Debug mode enabled.")
|
||||
else:
|
||||
# Reduce verbosity during HPO runs
|
||||
optuna_lg.setLevel(logging.WARNING)
|
||||
pl_logger.setLevel(logging.INFO) # Keep INFO for PL start/end messages
|
||||
|
||||
# --- Configuration Loading ---
|
||||
try:
|
||||
base_config = load_config(config_path)
|
||||
except Exception:
|
||||
base_config = load_config(config_path) # Load base config once
|
||||
logger.info(f"Successfully loaded configuration from {config_path}")
|
||||
except Exception as e:
|
||||
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
# --- Seed Setting (for HPO script itself) ---
|
||||
set_seeds(args.seed)
|
||||
# Setup output dir...
|
||||
if args.output_dir:
|
||||
hpo_base_output_dir = Path(args.output_dir) # Use specific name for HPO dir
|
||||
logger.info(f"Using HPO output directory from command line: {hpo_base_output_dir}")
|
||||
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
|
||||
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
|
||||
logger.info(f"Using HPO output directory from Optuna storage path: {hpo_base_output_dir}")
|
||||
else:
|
||||
# Use output_dir from main config if available, otherwise default
|
||||
main_output_dir = Path(getattr(base_config, 'output_dir', 'output'))
|
||||
hpo_base_output_dir = main_output_dir / 'hpo_results'
|
||||
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
|
||||
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# --- Load Data Once ---
|
||||
# Assume data doesn't change based on HPs (unless sequence_length is tuned heavily)
|
||||
# Setup logging... (ensure file handler uses hpo_base_output_dir)
|
||||
try:
|
||||
logger.info("Loading base dataset...")
|
||||
level_name = base_config.log_level.upper()
|
||||
effective_log_level = logging.getLevelName(level_name)
|
||||
log_file = hpo_base_output_dir / f"{base_config.optuna.study_name}_hpo.log"
|
||||
file_handler = logging.FileHandler(log_file, mode='a')
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
file_handler.setFormatter(formatter)
|
||||
# Add handler only if it's not already added (e.g., if run_hpo is called multiple times)
|
||||
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
|
||||
logger.addHandler(file_handler)
|
||||
logger.setLevel(effective_log_level)
|
||||
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
|
||||
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
|
||||
except (AttributeError, ValueError) as e:
|
||||
logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Setup seeding...
|
||||
try:
|
||||
seed = base_config.random_seed
|
||||
set_seeds(seed)
|
||||
except AttributeError:
|
||||
logger.warning("Config missing 'random_seed'. Using default seed 42.")
|
||||
set_seeds(42)
|
||||
|
||||
# Load data...
|
||||
try:
|
||||
logger.info("Loading base dataset for HPO...")
|
||||
df = load_raw_data(base_config.data)
|
||||
logger.info("Base dataset loaded.")
|
||||
logger.info(f"Base dataset loaded. Shape: {df.shape}")
|
||||
except Exception as e:
|
||||
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- Optuna Study Setup ---
|
||||
storage_path = args.storage_db
|
||||
if storage_path:
|
||||
# Ensure directory exists if using SQLite file storage
|
||||
db_path = Path(storage_path.replace("sqlite:///", ""))
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
storage_path = f"sqlite:///{db_path.resolve()}" # Use absolute path
|
||||
logger.info(f"Using Optuna storage: {storage_path}")
|
||||
else:
|
||||
logger.warning("No Optuna storage DB specified, using in-memory storage (results lost on exit).")
|
||||
try:
|
||||
hpo_config = base_config.optuna
|
||||
if not hpo_config.enabled:
|
||||
logger.info("Optuna optimization is disabled in the configuration.")
|
||||
sys.exit(0)
|
||||
except AttributeError:
|
||||
logger.critical("Optuna configuration section ('optuna') missing.")
|
||||
sys.exit(1)
|
||||
|
||||
storage_path = hpo_config.storage
|
||||
if storage_path and storage_path.startswith("sqlite:///"):
|
||||
db_path = Path(storage_path.replace("sqlite:///", ""))
|
||||
if not db_path.is_absolute():
|
||||
db_path = hpo_base_output_dir / db_path # Relative to HPO output dir
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
storage_path = f"sqlite:///{db_path.resolve()}"
|
||||
logger.info(f"Using Optuna storage: {storage_path}")
|
||||
elif storage_path:
|
||||
logger.info(f"Using Optuna storage: {storage_path} (non-SQLite)")
|
||||
else:
|
||||
logger.warning("No Optuna storage DB specified in config, using in-memory storage.")
|
||||
|
||||
try:
|
||||
# Create or load the study
|
||||
# Change to single objective 'minimize'
|
||||
study = optuna.create_study(
|
||||
study_name=args.study_name,
|
||||
study_name=hpo_config.study_name,
|
||||
storage=storage_path,
|
||||
direction=args.direction,
|
||||
load_if_exists=True, # Load previous results if study exists
|
||||
pruner=optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() # Example pruner
|
||||
direction="minimize", # Changed to single direction
|
||||
load_if_exists=True,
|
||||
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
|
||||
)
|
||||
# Remove multi-objective check/attribute setting
|
||||
# if not study._is_multi_objective:
|
||||
# logger.warning(f"Study '{hpo_config.study_name}' exists but is not multi-objective.")
|
||||
# elif 'objective_names' not in study.user_attrs:
|
||||
# study.set_user_attr('objective_names', objective_names)
|
||||
|
||||
|
||||
# --- Run Optimization ---
|
||||
logger.info(f"Starting Optuna optimization: study='{args.study_name}', n_trials={args.n_trials}, metric='{args.metric_to_optimize}', direction='{args.direction}'")
|
||||
logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}") # Updated log message
|
||||
study.optimize(
|
||||
lambda trial: objective(trial, base_config, df, output_dir, args.metric_to_optimize, args.pruning),
|
||||
n_trials=args.n_trials,
|
||||
timeout=None # Optional: Set timeout in seconds
|
||||
# Optional: Add callbacks (e.g., logging callback)
|
||||
lambda trial: objective(trial, base_config, df), # Pass base_config
|
||||
n_trials=hpo_config.n_trials,
|
||||
timeout=None,
|
||||
gc_after_trial=True
|
||||
)
|
||||
|
||||
# --- Report Best Trial ---
|
||||
# --- Report and Process Best Trial ---
|
||||
logger.info("--- Optuna HPO Finished ---")
|
||||
logger.info(f"Number of finished trials: {len(study.trials)}")
|
||||
|
||||
# Get the single best trial
|
||||
best_trial = study.best_trial
|
||||
logger.info(f"Best trial number: {best_trial.number}")
|
||||
logger.info(f" Best validation {args.metric_to_optimize}: {best_trial.value:.5f}")
|
||||
logger.info(" Best hyperparameters:")
|
||||
for key, value in best_trial.params.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
if best_trial is None:
|
||||
logger.warning("Optuna study finished, but no successful trial was completed.")
|
||||
else:
|
||||
logger.info(f"Best trial found (Trial {best_trial.number}):")
|
||||
# Log details for the best trial
|
||||
validation_metric_monitor = base_config.optuna.metric_to_optimize
|
||||
logger.info(f" Objective ({validation_metric_monitor}): {best_trial.value:.5f}") # Use .value for single objective
|
||||
logger.info(f" Hyperparameters:")
|
||||
for key, value in best_trial.params.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
# --- Save Best Hyperparameters (Optional) ---
|
||||
best_params_file = output_dir / f"{args.study_name}_best_params.json"
|
||||
try:
|
||||
with open(best_params_file, 'w') as f:
|
||||
import json
|
||||
json.dump(best_trial.params, f, indent=4)
|
||||
logger.info(f"Best hyperparameters saved to {best_params_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save best parameters: {e}")
|
||||
# --- Re-run and Save Artifacts for the Best Trial ---
|
||||
logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
|
||||
trial_output_dir = hpo_base_output_dir / f"best_trial_num{best_trial.number}" # Simplified directory name
|
||||
trial_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# 1. Create config for this trial
|
||||
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
|
||||
# Update with trial's hyperparameters
|
||||
for key, value in best_trial.params.items(): # Use best_trial.params
|
||||
# Need to place the param in the correct nested config dict
|
||||
if key in best_config_dict['training']: best_config_dict['training'][key] = value
|
||||
elif key in best_config_dict['model']: best_config_dict['model'][key] = value
|
||||
elif key in best_config_dict['features']: best_config_dict['features'][key] = value
|
||||
# Add more sections if HPO tunes them
|
||||
|
||||
# Add non-tuned features forecast_horizon back
|
||||
best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
|
||||
|
||||
# Ensure evaluation plots and model saving are enabled for this final run
|
||||
best_config_dict['evaluation']['save_plots'] = True
|
||||
best_config_dict['training']['save_model'] = True # Assuming you want to save the best model
|
||||
|
||||
# Validate the final config for this trial
|
||||
best_trial_config = MainConfig(**best_config_dict)
|
||||
|
||||
# Save the specific config used for this run
|
||||
with open(trial_output_dir / "best_config.yaml", 'w') as f:
|
||||
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
# 2. Run classic training (which saves model & plots)
|
||||
logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
|
||||
# Pass the specific config and output directory
|
||||
run_classic_training(
|
||||
config=best_trial_config,
|
||||
full_df=df,
|
||||
output_base_dir=trial_output_dir # outputs -> hpo_results/best_trial_num<n>/classic_run/
|
||||
)
|
||||
logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {trial_output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)
|
||||
|
||||
# --- Save Best Hyperparameters ---
|
||||
best_params_file = hpo_base_output_dir / f"{hpo_config.study_name}_best_params.json" # Simplified filename
|
||||
try:
|
||||
with open(best_params_file, 'w') as f:
|
||||
import json
|
||||
json.dump(best_trial.params, f, indent=4) # Use best_trial.params
|
||||
logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save parameters for best trial: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
|
||||
logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user