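"""Optuna hyperparameter optimization entry point (optuna_run.py).

Runs single-objective HPO for the LSTM forecasting model using a classic
train/val/test split, then re-runs the best trial and saves its artifacts,
hyperparameters, and configuration under the classic HPO output directory.
"""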
import argparse
import logging
import sys
import warnings

import copy
import json
import time
from pathlib import Path

import pandas as pd
import torch
import yaml

import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.utils.data_processing import (
    prepare_fold_data_and_loaders,
    split_data_classic
)
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.train.classic import run_model_training

# Import helper functions shared with forecasting_model_run.py
from forecasting_model.utils.helper import load_config, set_seeds

# Import the raw data loading function
from forecasting_model import load_raw_data

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional.
warnings.filterwarnings(
    "ignore",
    message=".*You called `self.log.*logger=True.*no logger configured.*",
    category=UserWarning,
    module="pytorch_lightning.core.module"
)

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()


# --- Argument Parsing (Simplified) ---
def parse_arguments():
    """Parses command-line arguments for Optuna HPO."""
    parser = argparse.ArgumentParser(
        description="Run Hyperparameter Optimization using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config', type=str, default='forecasting_config.yaml',
        help="Path to the YAML configuration file containing HPO settings."
    )
    parser.add_argument(
        '--output-dir', type=str, default=None,
        help="Override output directory specified in the configuration file."
    )
    args = parser.parse_args()
    return args
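
# Example invocation (paths and directory names are illustrative):
#   python optuna_run.py --config forecasting_config.yaml --output-dir output/hpo_run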


# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
) -> float:
    """
    Optuna single-objective function using a classic train/val/test split.

    Returns:
        The validation score from the classic split (to be minimized).
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Get HPO settings
    hpo_config = base_config.optuna
    # Use the specific metric name from config for validation monitoring
    validation_metric_monitor = hpo_config.metric_to_optimize
    # The objective is based on the validation metric; minimization is hardcoded
    # for this single-objective setup.
    monitor_mode = "min"
    worst_value = float('inf')  # Represents a poor validation score

    # --- 1. Suggest Hyperparameters ---
    try:
        # Deep copy the base config as a plain dict so the trial can modify it
        # (mode='python' keeps nested values easy to modify).
        trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
        return worst_value

    # ----- Suggest Hyperparameters -----
    # Each suggestion overwrites the corresponding entry in trial_config_dict.
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
    trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
    trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
    # Note: forecast_horizon is NOT tuned here; it is taken from base_config.
    trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
    use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
    if not use_configured_lags:
        trial_config_dict['features']['lags'] = []
    use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
    if not use_configured_rolling:
        trial_config_dict['features']['rolling_window_sizes'] = []
    trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
    trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
    trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
    trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
    # ----- End of Hyperparameter Suggestions -----

    # --- 2. Re-validate Trial Config ---
    try:
        # Transfer forecast horizon from base config (not tuned)
        trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
        trial_config = MainConfig(**trial_config_dict)
        logger.info(f"Trial {trial.number} Parameters: {trial.params}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
        return worst_value

    # --- Early check for invalid sequence length / forecast horizon combination ---
    try:
        # Make sure forecast_horizon is a non-empty list before calling max()
        if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
            raise ValueError("Trial config has an invalid forecast_horizon list.")
        min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
        if min_data_for_sequence > len(df):
            logger.warning(f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
                           f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
                           f"exceeds data length ({len(df)}).")
            # Optuna has no explicit "skip", so return the worst value instead.
            return worst_value
    except Exception as e:
        logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
        return worst_value

    # --- 3. Run Classic Train/Test ---
    logger.info(f"Trial {trial.number}: Starting Classic Run...")
    # Pre-initialize so the cleanup in `finally` is safe even if setup fails early.
    model_cl = trainer_cl = train_loader_cl = val_loader_cl = test_loader_cl = None
    try:
        n_samples = len(df)
        val_frac = trial_config.cross_validation.val_size_fraction
        test_frac = trial_config.cross_validation.test_size_fraction
        train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)

        # Prepare data for the classic split
        train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, data_scaler, input_size_cl = prepare_fold_data_and_loaders(
            full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
            target_col=trial_config.data.target_col, feature_config=trial_config.features,
            train_config=trial_config.training, eval_config=trial_config.evaluation
        )

        # Initialize model
        model_cl = LSTMForecastLightningModule(
            model_config=trial_config.model, train_config=trial_config.training,
            input_size=input_size_cl, target_scaler=target_scaler_cl, data_scaler=data_scaler
        )

        # Callbacks (EarlyStopping; pruning is handled via trial.report below)
        callbacks_cl = []
        if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
            callbacks_cl.append(EarlyStopping(
                monitor=validation_metric_monitor,  # Monitor the validation metric
                patience=trial_config.training.early_stopping_patience,
                mode=monitor_mode, verbose=False
            ))

        # Trainer for the classic run
        trainer_cl = pl.Trainer(
            check_val_every_n_epoch=trial_config.training.check_val_n_epoch,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1 if torch.cuda.is_available() else None,
            max_epochs=trial_config.training.epochs,
            callbacks=callbacks_cl,
            logger=False,  # No per-trial logger; metrics are read from callback_metrics
            enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
            gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
            precision=getattr(trial_config.training, 'precision', 32),
        )

        # --- Train Model ---
        logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
        trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)

        # --- Get Best Validation Score ---
        # Check the early stopping callback first, if it exists
        best_score_tensor = None
        if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
            if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
                best_score_tensor = callbacks_cl[0].best_score
            elif callbacks_cl[0].stopped_epoch > 0:  # Early stopping triggered
                logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")

        # If early stopping did not capture a best score, use the trainer's last metrics
        if best_score_tensor is None:
            metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
            if metric_val is not None:
                best_score_tensor = metric_val  # Use the last logged value

        if best_score_tensor is None:
            logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
            validation_metric_value = worst_value
        else:
            validation_metric_value = best_score_tensor.item()
            logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")

        # Report intermediate value for pruning (if enabled)
        trial.report(validation_metric_value, trainer_cl.current_epoch)
        if trial.should_prune():
            logger.info(f"Trial {trial.number}: Pruned.")
            raise optuna.TrialPruned()
|
# Note: We don't run prediction/evaluation on the test set here,
|
|
# as the objective is based on validation performance.
|
|
# The test set evaluation will be done later for the best trial.
|
|
|
|
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")
|
|
|
|
except optuna.TrialPruned:
|
|
# Propagate prune signal, objective will be set to "worst" later by Optuna
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)
|
|
validation_metric_value = worst_value # Assign worst value if classic run fails
|
|
finally:
|
|
# Clean up GPU memory after the run
|
|
del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl
|
|
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
|
|
|
|
|
# --- 4. Return Objective ---
|
|
logger.info(f"--- Trial {trial.number}: Finished ---")
|
|
logger.info(f" Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}")
|
|
logger.info(f" Total time: {time.perf_counter() - trial_start_time:.2f}s")
|
|
|
|
# Return the single objective (validation metric)
|
|
return float(validation_metric_value)
|
|
|
|
|
|


# --- Main HPO Execution ---
def run_hpo():
    """Main execution function for HPO."""
    args = parse_arguments()
    config_path = Path(args.config)
    try:
        base_config = load_config(config_path, MainConfig)  # Load the base config once
        logger.info(f"Successfully loaded configuration from {config_path}")
    except Exception as e:
        logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
        sys.exit(1)

    # --- Setup Output Dir ---
    # 1. Determine the main output directory
    if args.output_dir:
        main_output_dir = Path(args.output_dir)
        logger.info(f"Using main output directory from command line: {main_output_dir}")
    elif hasattr(base_config, 'output_dir') and base_config.output_dir:
        main_output_dir = Path(base_config.output_dir)
        logger.info(f"Using main output directory from config file: {main_output_dir}")
    else:
        main_output_dir = Path("output")  # Default
        logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")

    # 2. Define the specific directory for this classic HPO run
    classic_hpo_output_dir = main_output_dir / "classic"

    # 3. Create directories
    main_output_dir.mkdir(parents=True, exist_ok=True)
    classic_hpo_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Classic HPO outputs will be saved under: {classic_hpo_output_dir}")

    # --- Setup Logging ---
    # Ensure the study name is filesystem-safe before the try block, since it is
    # reused later for the DB and result filenames even if log setup fails.
    raw_study_name = getattr(getattr(base_config, 'optuna', None), 'study_name', None) or 'classic_hpo'
    safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in raw_study_name)
    try:
        level_name = base_config.log_level.upper()
        effective_log_level = logging.getLevelName(level_name)
        # Place the log file directly inside the classic HPO directory
        log_file = classic_hpo_output_dir / f"{safe_study_name}_hpo.log"
        file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        file_handler.setFormatter(formatter)
        # Add the handler only if it is not already attached
        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
            logger.addHandler(file_handler)
        logger.setLevel(effective_log_level)
        logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
        if effective_log_level <= logging.DEBUG:
            logger.debug("Debug logging enabled.")
    except (AttributeError, ValueError, TypeError) as e:
        logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
        logger.setLevel(logging.INFO)
        # Still try to log to a default file if possible
        try:
            # The default log file also goes into the classic HPO directory
            log_file = classic_hpo_output_dir / "default_classic_hpo.log"
            file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
            formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
            file_handler.setFormatter(formatter)
            if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
                logger.addHandler(file_handler)
            logger.info(f"Logging to default file: {log_file}")
        except Exception as log_e:
            logger.error(f"Failed to set up default file logging: {log_e}")

    # Setup seeding
    try:
        seed = base_config.random_seed
        set_seeds(seed)
    except AttributeError:
        logger.warning("Config missing 'random_seed'. Using default seed 42.")
        set_seeds(42)

    # Load data
    try:
        logger.info("Loading base dataset for HPO...")
        df = load_raw_data(base_config.data)
        logger.info(f"Base dataset loaded. Shape: {df.shape}")
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    try:
        hpo_config = base_config.optuna
        if not hpo_config.enabled:
            logger.info("Optuna optimization is disabled in the configuration.")
            sys.exit(0)
    except AttributeError:
        logger.critical("Optuna configuration section ('optuna') missing.")
        sys.exit(1)
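
    # Illustrative `optuna.storage` values (set in the YAML config):
    #   "sqlite:///hpo_study.db"  -> the DB file is placed inside the classic HPO output dir below
    #   empty / not set           -> in-memory study (results are lost when the process exits)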
    storage_string = hpo_config.storage
    storage_path = None  # Initialize

    if storage_string and storage_string.startswith("sqlite:///"):
        db_filename = storage_string.replace("sqlite:///", "").strip()
        if not db_filename:
            # Fall back to the study name if the filename is empty
            db_filename = f"{safe_study_name}.db"
            logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
        # Place the DB file inside the classic HPO directory
        db_path = classic_hpo_output_dir / db_filename
        storage_path = f"sqlite:///{db_path.resolve()}"
        logger.info(f"Using SQLite storage: {storage_path}")
    elif storage_string:
        # Assume it is a non-SQLite connection string or a pre-configured path
        storage_path = storage_string
        logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: the DB file will not be placed inside {classic_hpo_output_dir}")
    else:
        storage_path = None  # Explicitly None for in-memory storage
        logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")

    try:
        # Single-objective study; the optimization direction comes from the config
        # (expected to be 'minimize' for the validation metric).
        study = optuna.create_study(
            study_name=hpo_config.study_name,
            storage=storage_path,
            direction=hpo_config.direction,
            load_if_exists=True,
            pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
        )

        # --- Run Optimization ---
        logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
        study.optimize(
            lambda trial: objective(trial, base_config, df),
            n_trials=hpo_config.n_trials,
            timeout=None,
            gc_after_trial=True
        )

        # --- Report and Process Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")

        best_trial = None
        try:
            best_trial = study.best_trial
        except ValueError:
            logger.warning("Optuna study finished, but no successful trial was completed.")

        if best_trial:
            logger.info(f"Best trial found (Trial {best_trial.number}):")
            # Log details for the best trial
            validation_metric_monitor = base_config.optuna.metric_to_optimize
            logger.info(f"  Objective ({validation_metric_monitor}): {best_trial.value:.5f}")  # .value is valid for single-objective studies
            logger.info("  Hyperparameters:")
            for key, value in best_trial.params.items():
                logger.info(f"    {key}: {value}")

            # --- Re-run and Save Artifacts for the Best Trial ---
            logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
            # Output directory for this specific best-trial run
            best_trial_output_dir = classic_hpo_output_dir / f"best_trial_num{best_trial.number}"
            best_trial_output_dir.mkdir(parents=True, exist_ok=True)

            try:
                # 1. Create the config for this trial
                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
                # Update with the trial's hyperparameters, placing each one in its nested config section
                for key, value in best_trial.params.items():
                    if key in best_config_dict['training']:
                        best_config_dict['training'][key] = value
                    elif key in best_config_dict['model']:
                        best_config_dict['model'][key] = value
                    elif key in best_config_dict['features']:
                        best_config_dict['features'][key] = value
                    # Mirror the use_lags / use_rolling_windows handling from objective()
                    elif key == 'use_lags' and not value:
                        best_config_dict['features']['lags'] = []
                    elif key == 'use_rolling_windows' and not value:
                        best_config_dict['features']['rolling_window_sizes'] = []
                    # Add more sections here if HPO starts tuning them

                # Add the non-tuned forecast_horizon back
                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon

                # Ensure evaluation plots are enabled for this final run
                best_config_dict['evaluation']['save_plots'] = True
                # Model saving is assumed to be handled inside run_model_training
                # (e.g. via a `training.save_model` flag if the config supports one).

                best_trial_config = MainConfig(**best_config_dict)

                # Save the specific config used for this best run inside its directory
                with open(best_trial_output_dir / "best_run_config.yaml", 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

                # 2. Run classic training, saving outputs under best_trial_output_dir
                logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
                run_model_training(
                    config=best_trial_config,
                    full_df=df,
                    output_base_dir=best_trial_output_dir  # Artifacts for this run go here
                )
                logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {best_trial_output_dir}")

            except Exception as e:
                logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)

            # --- Save Best Hyperparameters and Config at the Top Level ---
            # Save the best parameters directly into the classic HPO output dir
            best_params_file = classic_hpo_output_dir / f"{safe_study_name}_best_params.json"
            try:
                with open(best_params_file, 'w', encoding='utf-8') as f:
                    json.dump(best_trial.params, f, indent=4)
                logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
            except Exception as e:
                logger.error(f"Failed to save parameters for best trial: {e}", exc_info=True)

            # Save the best config directly into the classic HPO output dir
            best_config_file = classic_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
            try:
                # best_config_dict still holds the config built in the re-run step above
                with open(best_config_file, 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
                logger.info(f"Configuration for best trial saved to {best_config_file}")
            except Exception as e:
                logger.error(f"Failed to save best configuration: {e}", exc_info=True)

    except optuna.exceptions.StorageInternalError as e:
        logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
        sys.exit(1)
if __name__ == "__main__":
|
|
run_hpo() |