import argparse
import copy
import json
import logging
import shutil
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import optuna
import pandas as pd
import yaml

from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model import TimeSeriesCrossValidationSplitter, load_raw_data
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation

# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior here, where logger=False is intentional
warnings.filterwarnings(
    "ignore",
    message=".*You called `self.log.*logger=True.*no logger configured.*",
    category=UserWarning,
    module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()


# --- Argument Parsing ---
def parse_arguments():
    """Parses command-line arguments for Optuna Ensemble HPO."""
    parser = argparse.ArgumentParser(
        description="Run HPO optimizing ensemble performance using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config', type=str, default='forecasting_config.yaml',
        help="Path to the YAML configuration file."
    )
    parser.add_argument(
        '--output-dir', type=str, default=None,
        help="Override base output directory for HPO results."
    )
    parser.add_argument(
        '--keep-artifacts', action='store_true',
        help="Prevent cleanup of trial directories after the best trial is determined."
    )
    args = parser.parse_args()
    return args
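
# Example invocation (the script file name here is hypothetical; adjust to
# your layout):
#   python optuna_ensemble_run.py --config forecasting_config.yaml \
#       --output-dir output/hpo --keep-artifacts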


# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
    ensemble_hpo_output_dir: Path
) -> float:  # Returns the single ensemble metric to optimize
    """
    Optuna objective function optimizing ensemble performance.

    Trains one model per cross-validation fold, evaluates the ensemble built
    from those fold models, and returns the averaged ensemble metric.
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Define trial-specific output directory for fold artifacts
    trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
    trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
    logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")

    hpo_config = base_config.optuna
    # Metric for pruning based on individual fold performance
    validation_metric_monitor = hpo_config.metric_to_optimize
    # Ensemble metric and method to optimize (e.g., MAE of the 'mean' ensemble)
    ensemble_metric_optimize = 'MAE'
    ensemble_method_optimize = 'mean'
    optimization_direction = hpo_config.direction  # 'minimize' or 'maximize'
    worst_value = float('inf') if optimization_direction == 'minimize' else float('-inf')

    # Store paths and details for all artifacts saved by this trial's folds
    fold_artifact_details: List[Dict[str, Any]] = []

    # --- 1. Suggest Hyperparameters ---
    try:
        trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed to deep copy base config: {e}", exc_info=True)
        return worst_value

    # ----- Suggest Hyperparameters -----
    # Modify trial_config_dict using trial.suggest_*
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256, 512])
    trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 3)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)
    trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 3, 72, step=2)
    trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
    use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
    if not use_configured_lags:
        trial_config_dict['features']['lags'] = []
    use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
    if not use_configured_rolling:
        trial_config_dict['features']['rolling_window_sizes'] = []
    trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
    trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
    trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
    trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
    # ----- End of Suggestions -----

    # --- 2. Re-validate Trial Config ---
    try:
        trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
        # Disable plotting during HPO runs to save time/resources
        trial_config_dict['evaluation']['save_plots'] = False
        trial_config = MainConfig(**trial_config_dict)
        logger.info(f"Trial {trial.number} Parameters: {trial.params}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid config generated: {e}", exc_info=True)
        return worst_value

    # --- Early check for data length ---
    try:
        if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
            raise ValueError("Trial config has invalid forecast_horizon list.")
        min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
        if min_data_for_sequence > len(df):
            logger.warning(f"Trial {trial.number}: Skipped. sequence_length + max_horizon ({min_data_for_sequence}) exceeds data length ({len(df)}).")
            # Report the worst value so Optuna knows this trial failed badly;
            # the study direction determines what "worst" means.
            return worst_value
            # raise optuna.TrialPruned()  # Alternative: prune instead of returning the worst value
    except Exception as e:
        logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
        return worst_value
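
    # Worked example (illustrative numbers): with sequence_length=72 and
    # forecast_horizon=[1, 12, 24], min_data_for_sequence = 72 + 24 = 96,
    # so a df with fewer than 96 rows is skipped before any training starts.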

    # --- 3. Run Cross-Validation Training (Saving Artifacts) ---
    all_fold_best_val_scores = {}  # Store best val scores for pruning
    actual_folds_trained = 0
    try:
        cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))

        for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
            fold_id = fold_num + 1
            logger.info(f"Trial {trial.number}, Fold {fold_id}/{cv_splitter.n_splits}: Training model...")
            current_fold_best_metric = None  # Reset for each fold

            try:
                # run_single_fold handles training and saving artifacts
                fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
                    fold_num=fold_num,
                    train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
                    config=trial_config,
                    full_df=df,
                    output_base_dir=trial_artifacts_dir,
                    enable_progress_bar=False
                )
                actual_folds_trained += 1
                all_fold_best_val_scores[fold_id] = best_val_score

                # Store all artifact paths for this fold
                fold_artifact_details.append({
                    "fold_id": fold_id,
                    "model_path": str(saved_model_path) if saved_model_path else None,
                    "target_scaler_path": str(saved_target_scaler_path) if saved_target_scaler_path else None,
                    "data_scaler_path": str(saved_data_scaler_path) if saved_data_scaler_path else None,
                    "input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
                    "config_path": str(saved_config_path) if saved_config_path else None,
                })

                # Check whether the monitored validation metric was returned
                if best_val_score is not None and np.isfinite(best_val_score):
                    current_fold_best_metric = best_val_score
                    logger.info(f"Trial {trial.number}, Fold {fold_id}: Best val score ({validation_metric_monitor}) = {current_fold_best_metric:.4f}")
                else:
                    # Use the worst value if the metric is missing or invalid
                    logger.warning(f"Trial {trial.number}, Fold {fold_id}: Invalid or missing validation score ({validation_metric_monitor}). Using {worst_value} for pruning.")
                    current_fold_best_metric = worst_value

                # Report the fold's validation score as an intermediate value for pruning
                trial.report(current_fold_best_metric, fold_num)
                if trial.should_prune():
                    logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
                    raise optuna.TrialPruned()

            except optuna.TrialPruned:
                raise  # Propagate the prune signal
            except Exception as e:
                logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
                all_fold_best_val_scores[fold_id] = None  # Mark fold as failed
                trial.report(worst_value, fold_num)

    except optuna.TrialPruned:
        logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
        return worst_value  # Return the worst value when pruned
    except Exception as e:
        logger.critical(f"Trial {trial.number}: Failed critically during CV training setup/loop: {e}", exc_info=True)
        return worst_value

    # --- 4. Run Ensemble Evaluation ---
    if actual_folds_trained < 2:
        logger.error(f"Trial {trial.number}: Only {actual_folds_trained} folds trained successfully. Cannot run ensemble evaluation.")
        return worst_value  # Not enough models for an ensemble

    logger.info(f"Trial {trial.number}: Starting Ensemble Evaluation using {actual_folds_trained} trained models...")
    ensemble_metric_final = worst_value  # Initialize to the worst value
    ensemble_results = None  # Initialize so the user-attr step below never sees an unbound name
    best_ensemble_method_for_trial = None  # Track the best method for this trial
    try:
        # Run evaluation using the artifacts saved in the trial's output directory
        ensemble_results = run_ensemble_evaluation(
            config=trial_config,
            output_base_dir=trial_artifacts_dir  # The specific trial's artifact dir
        )

        if ensemble_results:
            # Aggregate the results to get the final objective value
            ensemble_metrics_for_method = []
            for fold_num, fold_res in ensemble_results.items():
                if fold_res and ensemble_method_optimize in fold_res:
                    method_metrics = fold_res[ensemble_method_optimize]
                    if method_metrics and ensemble_metric_optimize in method_metrics:
                        metric_val = method_metrics[ensemble_metric_optimize]
                        if metric_val is not None and np.isfinite(metric_val):
                            ensemble_metrics_for_method.append(metric_val)
                        else:
                            logger.warning(f"Trial {trial.number}: Invalid ensemble metric value found for fold {fold_num}, method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
                    else:
                        logger.warning(f"Trial {trial.number}: Metric '{ensemble_metric_optimize}' not found for method '{ensemble_method_optimize}' in fold {fold_num}.")
                else:
                    logger.warning(f"Trial {trial.number}: Ensemble method '{ensemble_method_optimize}' results not found for fold {fold_num}.")

            if not ensemble_metrics_for_method:
                logger.error(f"Trial {trial.number}: No valid ensemble metrics found for method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
                ensemble_metric_final = worst_value
            else:
                # Average the chosen ensemble metric across test folds
                ensemble_metric_final = np.mean(ensemble_metrics_for_method)
                logger.info(f"Trial {trial.number}: Final Ensemble Metric (Avg {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")

                # Determine the best ensemble method based on average performance across folds.
                # This requires re-computing averages for *all* methods evaluated by run_ensemble_evaluation.
                all_ensemble_methods = set()
                for fold_res in ensemble_results.values():
                    if fold_res:
                        all_ensemble_methods.update(fold_res.keys())

                avg_metrics_per_method = {}
                for method in all_ensemble_methods:
                    method_metrics_across_folds = []
                    for fold_res in ensemble_results.values():
                        if fold_res and method in fold_res and ensemble_metric_optimize in fold_res[method]:
                            metric_val = fold_res[method][ensemble_metric_optimize]
                            if metric_val is not None and np.isfinite(metric_val):
                                method_metrics_across_folds.append(metric_val)
                    if method_metrics_across_folds:
                        avg_metrics_per_method[method] = np.mean(method_metrics_across_folds)

                if avg_metrics_per_method:
                    if optimization_direction == 'minimize':
                        best_ensemble_method_for_trial = min(avg_metrics_per_method, key=avg_metrics_per_method.get)
                    else:  # maximize
                        best_ensemble_method_for_trial = max(avg_metrics_per_method, key=avg_metrics_per_method.get)
                    logger.info(f"Trial {trial.number}: Best performing ensemble method for this trial (based on avg {ensemble_metric_optimize}): {best_ensemble_method_for_trial}")
                else:
                    logger.warning(f"Trial {trial.number}: Could not determine best ensemble method based on average {ensemble_metric_optimize}.")

        else:
            logger.error(f"Trial {trial.number}: Ensemble evaluation function returned no results.")
            ensemble_metric_final = worst_value

    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed during ensemble evaluation phase: {e}", exc_info=True)
        ensemble_metric_final = worst_value

    # --- 5. Return Final Objective Value ---
    trial_duration = time.perf_counter() - trial_start_time
    logger.info(f"--- Trial {trial.number}: Finished ---")
    logger.info(f"  Final Objective Value (Avg Ensemble {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
    logger.info(f"  Total time: {trial_duration:.2f}s")

    # Store the ensemble evaluation results and best method in trial user attributes.
    # This makes it easy to retrieve the best ensemble details after the study.
    trial.set_user_attr("ensemble_evaluation_results", ensemble_results)
    trial.set_user_attr("best_ensemble_method", best_ensemble_method_for_trial)
    trial.set_user_attr("fold_artifact_details", fold_artifact_details)

    return ensemble_metric_final


# --- Main HPO Execution ---
def run_hpo():
    """Main execution function for HPO optimizing ensemble performance."""
    args = parse_arguments()
    config_path = Path(args.config)
    try:
        base_config = load_config(config_path, MainConfig)
        logger.info(f"Successfully loaded configuration from {config_path}")
    except Exception as e:
        logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
        sys.exit(1)

    # --- Setup Output Dir ---
    # 1. Determine the main output directory
    if args.output_dir:
        # Command-line argument overrides config
        main_output_dir = Path(args.output_dir)
        logger.info(f"Using main output directory from command line: {main_output_dir}")
    elif hasattr(base_config, 'output_dir') and base_config.output_dir:
        main_output_dir = Path(base_config.output_dir)
        logger.info(f"Using main output directory from config file: {main_output_dir}")
    else:
        main_output_dir = Path("output")  # Default if not specified anywhere
        logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")

    # 2. Define the specific directory for this ensemble HPO run
    ensemble_hpo_output_dir = main_output_dir / "ensemble"

    # 3. Create directories
    main_output_dir.mkdir(parents=True, exist_ok=True)
    ensemble_hpo_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Ensemble HPO outputs will be saved under: {ensemble_hpo_output_dir}")

    # --- Setup Logging ---
    # Compute a filesystem-safe study name up front; it is also used later for
    # the SQLite DB filename and the best-trial output files.
    safe_study_name = "".join(
        c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name
    )
    try:
        level_name = base_config.log_level.upper()
        # logging.getLevelName() returns the numeric level when given a level name string
        effective_log_level = logging.getLevelName(level_name)
        # Place the log file directly inside the ensemble HPO directory
        log_file = ensemble_hpo_output_dir / f"{safe_study_name}_ensemble_hpo.log"
        file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        file_handler.setFormatter(formatter)
        # Prevent adding duplicate handlers if this function is called multiple times
        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
            logger.addHandler(file_handler)
        logger.setLevel(effective_log_level)
        logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
        if effective_log_level <= logging.DEBUG:
            logger.debug("Debug logging enabled.")
    except (AttributeError, ValueError, TypeError) as e:
        logger.warning(f"Could not set log level from config: {e}. Defaulting to INFO.")
        logger.setLevel(logging.INFO)
        # Still try to log to a default file if possible
        try:
            # The default log file also goes into the ensemble directory
            log_file = ensemble_hpo_output_dir / "default_ensemble_hpo.log"
            file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
            formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
            file_handler.setFormatter(formatter)
            if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
                logger.addHandler(file_handler)
            logger.info(f"Logging to default file: {log_file}")
        except Exception as log_e:
            logger.error(f"Failed to set up default file logging: {log_e}")

    # --- Setup Seeding ---
    set_seeds(getattr(base_config, 'random_seed', 42))

    # --- Load Data ---
    try:
        logger.info("Loading base dataset for HPO...")
        df = load_raw_data(base_config.data)
        logger.info(f"Base dataset loaded. Shape: {df.shape}")
    except FileNotFoundError as e:
        logger.critical(f"Data file not found: {e}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    try:
        hpo_config = base_config.optuna
        if not hpo_config.enabled:
            logger.info("Optuna optimization is disabled in the configuration.")
            sys.exit(0)
    except AttributeError:
        logger.critical("Optuna configuration section ('optuna') missing.")
        sys.exit(1)

    storage_string = hpo_config.storage
    storage_path = None  # Initialize

    if storage_string and storage_string.startswith("sqlite:///"):
        db_filename = storage_string.replace("sqlite:///", "").strip()
        if not db_filename:
            # Fall back to a name derived from the study name
            db_filename = f"{safe_study_name}_ensemble.db"
            logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
        # Place the DB file inside the ensemble HPO directory
        db_path = ensemble_hpo_output_dir / db_filename
        storage_path = f"sqlite:///{db_path.resolve()}"
        logger.info(f"Using SQLite storage: {storage_path}")
    elif storage_string:
        # Assume it is a non-SQLite connection string or a pre-configured path
        storage_path = storage_string
        logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: the DB file will not be placed inside {ensemble_hpo_output_dir}")
    else:
        storage_path = None  # Explicitly None for in-memory storage
        logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")

    try:
        # Single-objective study based on ensemble performance
        study = optuna.create_study(
            study_name=hpo_config.study_name,
            storage=storage_path,
            direction=hpo_config.direction,  # 'minimize' or 'maximize'
            load_if_exists=True,
            pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
        )

        # --- Run Optimization ---
        logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
        study.optimize(
            lambda trial: objective(trial, base_config, df, ensemble_hpo_output_dir),
            n_trials=hpo_config.n_trials,
            timeout=None,
            gc_after_trial=True  # Garbage-collect after each trial
        )

        # --- Report and Save Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")

        # Find the actual best trial (excludes pruned/failed trials)
        try:
            best_trial = study.best_trial
        except ValueError:  # Optuna raises ValueError if no trial completed successfully
            best_trial = None
            logger.warning("No successful trials completed. Cannot determine best trial.")

        if best_trial:
            logger.info("--- Best Trial ---")
            logger.info(f"  Trial Number: {best_trial.number}")
            # Ensure the value is not None before formatting
            best_value_str = f"{best_trial.value:.6f}" if best_trial.value is not None else "N/A"
            logger.info(f"  Objective Value (Ensemble Metric): {best_value_str}")
            logger.info("  Hyperparameters:")
            best_params = best_trial.params
            for key, value in best_params.items():
                logger.info(f"    {key}: {value}")

            # Save best hyperparameters directly into the ensemble output dir
            best_params_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_params.json"
            try:
                with open(best_params_file, 'w', encoding='utf-8') as f:
                    json.dump(best_params, f, indent=4)
                logger.info(f"Best hyperparameters saved to {best_params_file}")
            except Exception as e:
                logger.error(f"Failed to save best parameters: {e}", exc_info=True)

            # Save the corresponding config directly into the ensemble output dir
            best_config_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
            try:
                # Use a fresh deepcopy to avoid modifying the original base_config
                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))

                # Update with the best trial's hyperparameters.
                # Pitfall: this assumes keys match exactly and exist in these sections.
                # A more robust approach might check key existence or walk the
                # config structure if params are nested differently.
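                # Sketch of such a fallback (illustrative only, not wired in;
                # assumes parameter names are unique across the config dict):
                #
                #     def set_nested(cfg: dict, key: str, value) -> bool:
                #         if key in cfg:
                #             cfg[key] = value
                #             return True
                #         return any(isinstance(v, dict) and set_nested(v, key, value)
                #                    for v in cfg.values())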
                for key, value in best_params.items():
                    if key in best_config_dict.get('training', {}):
                        best_config_dict['training'][key] = value
                    elif key in best_config_dict.get('model', {}):
                        best_config_dict['model'][key] = value
                    elif key in best_config_dict.get('features', {}):
                        best_config_dict['features'][key] = value
                    elif key in ["use_lags", "use_rolling_windows"]:
                        # If False, the corresponding feature list was already
                        # set to [] in the parameter suggestion section.
                        pass
                    else:
                        logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")

                # Ensure the forecast horizon is preserved from the original config
                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
                # Maybe remove the optuna section from the best config?
                # best_config_dict.pop('optuna', None)

                with open(best_config_file, 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
                logger.info(f"Configuration for best trial saved to {best_config_file}")

            except Exception as e:
                logger.error(f"Failed to save best configuration: {e}", exc_info=True)

            # --- Retrieve Artifacts and Save Ensemble Definition ---
            # Base directory for the best trial's artifacts
            best_trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
            best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
            fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", [])

            # Check that the artifacts and required metadata are available
            if not best_trial_artifacts_dir.exists():
                logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
            elif not best_ensemble_method:
                logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
            elif not fold_artifact_details:
                logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
            else:
                # --- Save Best Ensemble Definition ---
                logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")

                # The definition file goes directly into the ensemble output dir
                ensemble_definition_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_ensemble.json"

                best_ensemble_definition = {
                    "trial_number": best_trial.number,
                    "objective_value": best_trial.value,
                    "hyperparameters": best_trial.params,
                    "ensemble_method": best_ensemble_method,
                    # Artifact base dir, relative to the main ensemble output dir
                    "ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(ensemble_hpo_output_dir)),
                    "fold_models": [],
                }

                # Populate fold_models with paths relative to best_trial_artifacts_dir
                for artifact_detail in fold_artifact_details:
                    fold_def = {
                        "fold_id": artifact_detail.get("fold_id"),
                        "model_path": None,
                        "target_scaler_path": None,
                        "data_scaler_path": None,
                        "input_size_path": None,
                        "config_path": None,
                    }

                    # Process each path, making it relative where possible
                    for key in ["model_path", "target_scaler_path", "data_scaler_path", "input_size_path", "config_path"]:
                        abs_path_str = artifact_detail.get(key)
                        if abs_path_str:
                            abs_path = Path(abs_path_str).absolute()
                            try:
                                # Make the path relative to the trial artifacts dir (where models/scalers reside)
                                relative_path = str(abs_path.relative_to(best_trial_artifacts_dir.absolute()))
                                fold_def[key] = relative_path
                            except ValueError:
                                # Should not happen if paths were saved correctly, but handle it just in case
                                logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
                                fold_def[key] = str(abs_path)  # Fall back to the absolute path

                    best_ensemble_definition["fold_models"].append(fold_def)

                try:
                    with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
                        json.dump(best_ensemble_definition, f, indent=4)
                    logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
                except Exception as e:
                    logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)

        # --- Optional: Clean up artifact directories for non-best trials ---
        if not args.keep_artifacts:
            logger.info("Cleaning up artifact directories for non-best trials...")
            # The base path for all trial artifacts within the ensemble dir
            ensemble_artifacts_base_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts"
            if ensemble_artifacts_base_dir.exists():
                for item in ensemble_artifacts_base_dir.iterdir():
                    if item.is_dir():
                        # Keep the directory belonging to the best trial
                        if best_trial and item.name == f"trial_{best_trial.number}":
                            logger.debug(f"Keeping artifact directory for best trial: {item}")
                            continue
                        logger.debug(f"Removing artifact directory for non-best trial: {item}")
                        try:
                            shutil.rmtree(item)
                        except Exception as e:
                            logger.error(f"Failed to remove directory {item}: {e}", exc_info=True)
            else:
                logger.debug(f"Artifacts base directory not found for cleanup: {ensemble_artifacts_base_dir}")

    except optuna.exceptions.StorageInternalError as e:
        logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_hpo()