intermediate backup
optuna_ensemble_run.py (new file, 603 lines)
@@ -0,0 +1,603 @@
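"""
Optuna HPO driver that optimizes *ensemble* performance.

Each trial samples a hyperparameter set, trains one model per cross-validation
fold via run_single_fold (saving model/scaler artifacts per fold), evaluates
ensembles of the fold models with run_ensemble_evaluation, and returns the
averaged ensemble metric as the objective value reported to Optuna.
"""
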
import argparse
import logging
import sys
import warnings
import copy  # For deep copying config
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml
import json  # Import json to save the best ensemble definition

import optuna

# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
    # prepare_fold_data_and_loaders is used by run_single_fold
)
# Import the single fold runner from the main script
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
from typing import List, Optional, Tuple, Dict, Any

# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds, aggregate_cv_metrics, save_results

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in these HPO runs, where logger=False is intentional.
warnings.filterwarnings(
    "ignore",
    message=".*You called `self.log.*logger=True.*no logger configured.*",
    category=UserWarning,  # These specific warnings are often UserWarnings
    module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)  # Keep PL at WARNING by default; INFO/DEBUG is set below if needed

# --- Basic Logging Setup ---
# Configure logging early. The level is set properly later based on the config.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()


# --- Argument Parsing ---
def parse_arguments():
    """Parses command-line arguments for Optuna ensemble HPO."""
    parser = argparse.ArgumentParser(
        description="Run HPO optimizing ensemble performance using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config', type=str, default='forecasting_config.yaml',
        help="Path to the YAML configuration file."
    )
    parser.add_argument(
        '--output-dir', type=str, default=None,
        help="Override base output directory for HPO results."
    )
    parser.add_argument(
        '--keep-artifacts', action='store_true',
        help="Prevent cleanup of trial directories after the best trial is determined."
    )
    args = parser.parse_args()
    return args


# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
    hpo_base_output_dir: Path  # Base directory for trial outputs
) -> float:  # Returns the single ensemble metric to optimize
    """
    Optuna objective function optimizing ensemble performance.
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Define trial-specific output directory for fold artifacts
    trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
    trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
    logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")

    hpo_config = base_config.optuna
    # Metric for pruning, based on individual fold performance
    validation_metric_monitor = hpo_config.metric_to_optimize
    # Ensemble metric and method to optimize (e.g., MAE of the 'mean' ensemble)
    ensemble_metric_optimize = 'MAE'
    ensemble_method_optimize = 'mean'
    optimization_direction = hpo_config.direction  # 'minimize' or 'maximize'
    worst_value = float('inf') if optimization_direction == 'minimize' else float('-inf')

    # Store paths and details of all artifacts saved by this trial's folds
    fold_artifact_details: List[Dict[str, Any]] = []

    # --- 1. Suggest Hyperparameters ---
    try:
        trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed to deep copy base config: {e}", exc_info=True)
        return worst_value

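    # All suggestions below mutate the plain dict copy; the result is re-validated
    # against MainConfig in step 2, so invalid combinations fail fast for this
    # trial instead of crashing training.
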
    # ----- Suggest Hyperparameters -----
    # Modify trial_config_dict using trial.suggest_*
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
    trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
    trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
    trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
    use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
    if not use_configured_lags:
        trial_config_dict['features']['lags'] = []
    use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
    if not use_configured_rolling:
        trial_config_dict['features']['rolling_window_sizes'] = []
    trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
    trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
    trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
    trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
    # ----- End of Suggestions -----

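    # Note: the 'use_lags' / 'use_rolling_windows' flags above only clear the lists
    # already present in the config; they do not suggest new lag or window values.
    # Additional search dimensions can be added in the same pattern, for example
    # (hypothetical key, assuming the training config defines it):
    #   trial_config_dict['training']['weight_decay'] = trial.suggest_float(
    #       'weight_decay', 1e-6, 1e-2, log=True)
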
    # --- 2. Re-validate Trial Config ---
    try:
        trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
        # Disable plotting during HPO runs to save time and resources
        trial_config_dict['evaluation']['save_plots'] = False
        trial_config = MainConfig(**trial_config_dict)
        logger.info(f"Trial {trial.number} Parameters: {trial.params}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid config generated: {e}", exc_info=True)
        return worst_value

    # --- Early check for data length ---
    try:
        if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
            raise ValueError("Trial config has an invalid forecast_horizon list.")
        min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
        if min_data_for_sequence > len(df):
            logger.warning(
                f"Trial {trial.number}: Skipped. sequence_length + max_horizon ({min_data_for_sequence}) "
                f"exceeds data length ({len(df)})."
            )
            # Report the worst value so Optuna knows this trial failed badly;
            # the study direction determines what "worst" means here.
            return worst_value
            # raise optuna.TrialPruned()  # Alternative: prune instead of returning the worst value
    except Exception as e:
        logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
        return worst_value

    # --- 3. Run Cross-Validation Training (Saving Artifacts) ---
    all_fold_best_val_scores = {}  # Best validation score per fold, used for pruning
    actual_folds_trained = 0

    try:
        cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))

        for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
            fold_id = fold_num + 1
            logger.info(f"Trial {trial.number}, Fold {fold_id}/{cv_splitter.n_splits}: Training model...")
            current_fold_best_metric = None  # Reset for each fold

            try:
                # run_single_fold handles training and saving artifacts.
                # Passing trial_artifacts_dir keeps fold artifacts grouped per trial.
                fold_metrics, best_val_score, saved_model_path, saved_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
                    fold_num=fold_num,
                    train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
                    config=trial_config,  # Config with this trial's hyperparameters
                    full_df=df,
                    output_base_dir=trial_artifacts_dir  # Save folds under the trial dir
                )
                actual_folds_trained += 1
                all_fold_best_val_scores[fold_id] = best_val_score

                # Store all artifact paths for this fold
                fold_artifact_details.append({
                    "fold_id": fold_id,
                    "model_path": str(saved_model_path) if saved_model_path else None,
                    "target_scaler_path": str(saved_scaler_path) if saved_scaler_path else None,
                    "input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
                    "config_path": str(saved_config_path) if saved_config_path else None,
                })

                # Check whether the monitored validation metric was returned
                if best_val_score is not None and np.isfinite(best_val_score):
                    current_fold_best_metric = best_val_score
                    logger.info(f"Trial {trial.number}, Fold {fold_id}: Best val score ({validation_metric_monitor}) = {current_fold_best_metric:.4f}")
                else:
                    # Use the worst value if the metric is missing or invalid
                    logger.warning(f"Trial {trial.number}, Fold {fold_id}: Invalid or missing validation score ({validation_metric_monitor}). Using {worst_value} for pruning.")
                    current_fold_best_metric = worst_value  # Assign worst value for the pruning report

                # Report the intermediate value (this fold's validation score) for pruning
                trial.report(current_fold_best_metric, fold_num)
                if trial.should_prune():
                    logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
                    raise optuna.TrialPruned()

            except optuna.TrialPruned:
                raise  # Propagate the prune signal
            except Exception as e:
                logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
                all_fold_best_val_scores[fold_id] = None  # Mark fold as failed
                # Continue to the next fold, but report the worst value for this one
                trial.report(worst_value, fold_num)
                # Optionally prune here if too many folds fail, or let the ensemble evaluation handle it.

    except optuna.TrialPruned:
        logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
        return worst_value  # Return the worst value when pruned
    except Exception as e:
        logger.critical(f"Trial {trial.number}: Failed critically during CV training setup/loop: {e}", exc_info=True)
        return worst_value

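    # At this point each successfully trained fold has its own subdirectory under
    # trial_artifacts_dir holding the model checkpoint, target scaler, input-size
    # file and fold config; run_ensemble_evaluation below is expected to discover
    # those per-fold artifacts and combine the fold models' test predictions.
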
    # --- 4. Run Ensemble Evaluation ---
    if actual_folds_trained < 2:
        logger.error(f"Trial {trial.number}: Only {actual_folds_trained} folds trained successfully. Cannot run ensemble evaluation.")
        return worst_value  # Not enough models for an ensemble

    logger.info(f"Trial {trial.number}: Starting Ensemble Evaluation using {actual_folds_trained} trained models...")
    ensemble_results = None  # Keep defined even if the evaluation below fails
    ensemble_metric_final = worst_value  # Initialize to the worst value
    best_ensemble_method_for_trial = None  # Track the best method for this trial
    try:
        # Run evaluation using the artifacts saved in the trial's output directory
        ensemble_results = run_ensemble_evaluation(
            config=trial_config,  # Pass the trial config
            output_base_dir=trial_artifacts_dir  # Directory containing the trial's fold subdirs
        )

        if ensemble_results:
            # Aggregate the results to get the final objective value
            ensemble_metrics_for_method = []
            for fold_num, fold_res in ensemble_results.items():
                if fold_res and ensemble_method_optimize in fold_res:
                    method_metrics = fold_res[ensemble_method_optimize]
                    if method_metrics and ensemble_metric_optimize in method_metrics:
                        metric_val = method_metrics[ensemble_metric_optimize]
                        if metric_val is not None and np.isfinite(metric_val):
                            ensemble_metrics_for_method.append(metric_val)
                        else:
                            logger.warning(f"Trial {trial.number}: Invalid ensemble metric value found for fold {fold_num}, method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
                    else:
                        logger.warning(f"Trial {trial.number}: Metric '{ensemble_metric_optimize}' not found for method '{ensemble_method_optimize}' in fold {fold_num}.")
                else:
                    logger.warning(f"Trial {trial.number}: Ensemble method '{ensemble_method_optimize}' results not found for fold {fold_num}.")

            if not ensemble_metrics_for_method:
                logger.error(f"Trial {trial.number}: No valid ensemble metrics found for method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
                ensemble_metric_final = worst_value
            else:
                # Mean of the chosen ensemble metric across test folds
                ensemble_metric_final = np.mean(ensemble_metrics_for_method)
                logger.info(f"Trial {trial.number}: Final Ensemble Metric (Avg {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")

                # Determine the best ensemble method based on average performance across folds.
                # This requires averaging *all* methods evaluated by run_ensemble_evaluation.
                all_ensemble_methods = set()
                for fold_res in ensemble_results.values():
                    if fold_res:
                        all_ensemble_methods.update(fold_res.keys())

                avg_metrics_per_method = {}
                for method in all_ensemble_methods:
                    method_metrics_across_folds = []
                    for fold_res in ensemble_results.values():
                        if fold_res and method in fold_res and ensemble_metric_optimize in fold_res[method]:
                            metric_val = fold_res[method][ensemble_metric_optimize]
                            if metric_val is not None and np.isfinite(metric_val):
                                method_metrics_across_folds.append(metric_val)
                    if method_metrics_across_folds:
                        avg_metrics_per_method[method] = np.mean(method_metrics_across_folds)

                if avg_metrics_per_method:
                    if optimization_direction == 'minimize':
                        best_ensemble_method_for_trial = min(avg_metrics_per_method, key=avg_metrics_per_method.get)
                    else:  # maximize
                        best_ensemble_method_for_trial = max(avg_metrics_per_method, key=avg_metrics_per_method.get)
                    logger.info(f"Trial {trial.number}: Best performing ensemble method for this trial (based on avg {ensemble_metric_optimize}): {best_ensemble_method_for_trial}")
                else:
                    logger.warning(f"Trial {trial.number}: Could not determine the best ensemble method based on average {ensemble_metric_optimize}.")

        else:
            logger.error(f"Trial {trial.number}: Ensemble evaluation function returned no results.")
            ensemble_metric_final = worst_value

    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed during ensemble evaluation phase: {e}", exc_info=True)
        ensemble_metric_final = worst_value

    # --- 5. Return Final Objective Value ---
    trial_duration = time.perf_counter() - trial_start_time
    logger.info(f"--- Trial {trial.number}: Finished ---")
    logger.info(f"  Final Objective Value (Avg Ensemble {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
    logger.info(f"  Total time: {trial_duration:.2f}s")

    # Store the ensemble evaluation results, best method and artifact paths in the
    # trial's user attributes so the best ensemble can be reconstructed after the study.
    trial.set_user_attr("ensemble_evaluation_results", ensemble_results)
    trial.set_user_attr("best_ensemble_method", best_ensemble_method_for_trial)
    trial.set_user_attr("fold_artifact_details", fold_artifact_details)

    return ensemble_metric_final

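# Note on pruning: intermediate values reported to the pruner are per-fold
# *validation* scores (validation_metric_monitor), while the final objective is
# the averaged *ensemble* test metric; a trial can therefore be pruned early on
# weak fold validation performance before any ensemble is evaluated.

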
# --- Main HPO Execution ---
def run_hpo():
    """Main execution function for HPO optimizing ensemble performance."""
    args = parse_arguments()
    config_path = Path(args.config)
    try:
        base_config = load_config(config_path)
        logger.info(f"Successfully loaded configuration from {config_path}")
    except Exception as e:
        logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
        sys.exit(1)

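    # Output directory precedence: the --output-dir flag wins, then the parent
    # directory of a SQLite storage path given in the config, then
    # <output_dir>/<study_name>_ensemble_hpo under the config's output_dir.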
    # --- Setup Output Dir ---
    if args.output_dir:
        hpo_base_output_dir = Path(args.output_dir)
    elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
        hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
    else:
        # Fall back to the config's output_dir (or 'output' if missing/empty)
        main_output_dir_str = getattr(base_config, 'output_dir', 'output')
        if not main_output_dir_str:  # Handle empty string case
            main_output_dir_str = 'output'
        main_output_dir = Path(main_output_dir_str)
        hpo_base_output_dir = main_output_dir / f'{base_config.optuna.study_name}_ensemble_hpo'  # Study-specific subdir
    hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using HPO output directory: {hpo_base_output_dir}")

    # --- Setup Logging ---
    # Compute a filesystem-safe study name up front; it is reused for log and
    # result file names even if the log-level setup below fails.
    safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name)
    try:
        level_name = base_config.log_level.upper()
        effective_log_level = logging.getLevelName(level_name)
        log_file = hpo_base_output_dir / f"{safe_study_name}_ensemble_hpo.log"
        file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        file_handler.setFormatter(formatter)
        # Prevent adding duplicate handlers if this function is called multiple times
        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
            logger.addHandler(file_handler)
        logger.setLevel(effective_log_level)
        logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
        if effective_log_level <= logging.DEBUG:
            logger.debug("Debug logging enabled.")
    except (AttributeError, ValueError, TypeError) as e:
        logger.warning(f"Could not set log level from config: {e}. Defaulting to INFO.")
        logger.setLevel(logging.INFO)
        # Still try to log to a default file if possible
        try:
            log_file = hpo_base_output_dir / "default_ensemble_hpo.log"
            file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
            formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
            file_handler.setFormatter(formatter)
            if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
                logger.addHandler(file_handler)
            logger.info(f"Logging to default file: {log_file}")
        except Exception as log_e:
            logger.error(f"Failed to set up default file logging: {log_e}")

    # --- Setup Seeding ---
    set_seeds(getattr(base_config, 'random_seed', 42))

    # --- Load Data ---
    try:
        logger.info("Loading base dataset for HPO...")
        df = load_raw_data(base_config.data)
        logger.info(f"Base dataset loaded. Shape: {df.shape}")
    except FileNotFoundError as e:
        logger.critical(f"Data file not found: {e}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    try:
        hpo_config = base_config.optuna
        if not hpo_config.enabled:
            logger.info("Optuna optimization is disabled in the configuration.")
            sys.exit(0)
    except AttributeError:
        logger.critical("Optuna configuration section ('optuna') missing.")
        sys.exit(1)

    storage_path = hpo_config.storage
    if storage_path and storage_path.startswith("sqlite:///"):
        db_path_str = storage_path.replace("sqlite:///", "")
        if not db_path_str:
            # Default filename if only 'sqlite:///' is provided
            db_path = hpo_base_output_dir / f"{base_config.optuna.study_name}.db"
            logger.warning(f"SQLite path was empty, defaulting to: {db_path}")
        else:
            db_path = Path(db_path_str)

        if not db_path.is_absolute():
            db_path = hpo_base_output_dir / db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure the parent dir exists
        storage_path = f"sqlite:///{db_path.resolve()}"
        logger.info(f"Using SQLite storage: {storage_path}")
    elif storage_path:
        logger.info(f"Using Optuna storage: {storage_path} (assuming non-SQLite or pre-configured)")
    else:
        storage_path = None  # Explicitly None for in-memory storage
        logger.warning("No Optuna storage DB specified, using in-memory storage.")

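    # With a persistent storage URL and load_if_exists=True below, re-running this
    # script resumes the existing study and simply adds more trials to it.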
    try:
        # Single-objective study based on ensemble performance
        study = optuna.create_study(
            study_name=hpo_config.study_name,
            storage=storage_path,
            direction=hpo_config.direction,  # 'minimize' or 'maximize'
            load_if_exists=True,
            pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
        )

        # --- Run Optimization ---
        logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
        study.optimize(
            lambda trial: objective(trial, base_config, df, hpo_base_output_dir),
            n_trials=hpo_config.n_trials,
            timeout=None,
            gc_after_trial=True  # Run garbage collection after each trial to free memory
        )

        # --- Report and Save Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")

        # Find the actual best trial (excluding pruned/failed trials)
        try:
            best_trial = study.best_trial
        except ValueError:  # Optuna raises ValueError if no trial completed successfully
            best_trial = None
            logger.warning("No successful trials completed. Cannot determine best trial.")

        if best_trial:
            logger.info("--- Best Trial ---")
            logger.info(f"  Trial Number: {best_trial.number}")
            # Ensure the value is not None before formatting
            best_value_str = f"{best_trial.value:.6f}" if best_trial.value is not None else "N/A"
            logger.info(f"  Objective Value (Ensemble Metric): {best_value_str}")
            logger.info("  Hyperparameters:")
            best_params = best_trial.params
            for key, value in best_params.items():
                logger.info(f"    {key}: {value}")

            # Save best hyperparameters
            best_params_file = hpo_base_output_dir / f"{safe_study_name}_best_params.json"
            try:
                with open(best_params_file, 'w', encoding='utf-8') as f:
                    json.dump(best_params, f, indent=4)
                logger.info(f"Best hyperparameters saved to {best_params_file}")
            except Exception as e:
                logger.error(f"Failed to save best parameters: {e}", exc_info=True)

            # Save the corresponding config
            best_config_file = hpo_base_output_dir / f"{safe_study_name}_best_config.yaml"
            try:
                # Use a fresh deep copy to avoid modifying the original base_config
                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))

                # Update with the best trial's hyperparameters.
                # Pitfall: this assumes each parameter key exists in exactly one of these
                # sections. A more robust approach would check key existence or walk the
                # config structure if parameters are nested differently.
                for key, value in best_params.items():
                    if key in best_config_dict.get('training', {}):
                        best_config_dict['training'][key] = value
                    elif key in best_config_dict.get('model', {}):
                        best_config_dict['model'][key] = value
                    elif key in best_config_dict.get('features', {}):
                        best_config_dict['features'][key] = value
                    else:
                        logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")

                # Ensure the forecast horizon is preserved from the original config
                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
                # Optionally drop the optuna section from the best config:
                # best_config_dict.pop('optuna', None)

                with open(best_config_file, 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
                logger.info(f"Configuration for best trial saved to {best_config_file}")

            except Exception as e:
                logger.error(f"Failed to save best configuration: {e}", exc_info=True)

            # Retrieve saved artifact paths and the best ensemble method from user attributes
            best_trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
            best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
            fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", [])

            if not best_trial_artifacts_dir.exists():
                logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
            elif not best_ensemble_method:
                logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
            elif not fold_artifact_details:  # No artifact details were recorded
                logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
            else:
                # --- Save Best Ensemble Definition ---
                logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")

                ensemble_definition_file = hpo_base_output_dir / f"{safe_study_name}_best_ensemble.json"

                best_ensemble_definition = {
                    "trial_number": best_trial.number,
                    "objective_value": best_trial.value,
                    "hyperparameters": best_trial.params,
                    "ensemble_method": best_ensemble_method,
                    "fold_models": [],  # One entry per fold: model, scaler, input-size and config paths
                    "ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(hpo_base_output_dir))  # Relative to hpo_base_output_dir
                }

                # Populate fold_models with paths to the saved artifacts
                for artifact_detail in fold_artifact_details:
                    fold_def = {
                        "fold_id": artifact_detail.get("fold_id"),
                        "model_path": None,
                        "target_scaler_path": None,
                        "input_size_path": None,
                        "config_path": None,
                    }

                    # Process each path, making it relative to the trial directory where possible
                    for key in ["model_path", "target_scaler_path", "input_size_path", "config_path"]:
                        abs_path_str = artifact_detail.get(key)
                        if abs_path_str:
                            abs_path = Path(abs_path_str)
                            try:
                                fold_def[key] = str(abs_path.relative_to(best_trial_artifacts_dir))
                            except ValueError:
                                logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
                                fold_def[key] = str(abs_path)  # Fall back to the absolute path

                    best_ensemble_definition["fold_models"].append(fold_def)

                try:
                    with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
                        json.dump(best_ensemble_definition, f, indent=4)
                    logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
                except Exception as e:
                    logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)

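        # For reference, the saved ensemble definition has roughly this shape
        # (illustrative values; artifact paths are relative to the trial's
        # artifact directory where possible):
        # {
        #     "trial_number": 17,
        #     "objective_value": 12.34,
        #     "hyperparameters": {"learning_rate": 0.001, "...": "..."},
        #     "ensemble_method": "mean",
        #     "fold_models": [
        #         {"fold_id": 1, "model_path": "...", "target_scaler_path": "...",
        #          "input_size_path": "...", "config_path": "..."}
        #     ],
        #     "ensemble_artifacts_base_dir": "ensemble_runs_artifacts/trial_17"
        # }
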
        # --- Optional: Clean up artifact directories for non-best trials ---
        if not args.keep_artifacts:
            logger.info("Cleaning up artifact directories for non-best trials...")
            import shutil
            ensemble_artifacts_base_dir = hpo_base_output_dir / "ensemble_runs_artifacts"
            if ensemble_artifacts_base_dir.exists():
                for item in ensemble_artifacts_base_dir.iterdir():
                    if not item.is_dir():
                        continue
                    # Keep the directory belonging to the best trial
                    if best_trial and item.name == f"trial_{best_trial.number}":
                        logger.debug(f"Keeping artifact directory for best trial: {item}")
                        continue
                    logger.debug(f"Removing artifact directory for non-best trial: {item}")
                    try:
                        shutil.rmtree(item)
                    except Exception as e:
                        logger.error(f"Failed to remove directory {item}: {e}", exc_info=True)
            else:
                logger.debug(f"Artifacts base directory not found for cleanup: {ensemble_artifacts_base_dir}")

    except optuna.exceptions.StorageInternalError as e:
        logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_hpo()