# entrix_case_challange/optuna_ensemble_run.py
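"""
Optuna HPO driver that optimizes ensemble performance.

Each trial trains one model per cross-validation fold via run_single_fold, then
scores the ensemble of fold models with run_ensemble_evaluation and returns the
average ensemble metric (mean-ensemble MAE by default) as the objective value.
Artifacts, logs, the optional Optuna SQLite DB, and the best-trial definition
files are written under <output_dir>/ensemble/.
"""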
import argparse
import logging
import sys
import warnings
import copy
import json
import shutil
from pathlib import Path
import time
import numpy as np
import pandas as pd
import yaml
import optuna
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model import TimeSeriesCrossValidationSplitter, load_raw_data
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
from typing import List, Dict, Any
# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds
# --- Suppress specific PL warnings about logger=True with no logger ---
# Expected here: trial runs intentionally use logger=False, so these warnings are noise
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning,
module="pytorch_lightning.core.module"
)
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing ---
def parse_arguments():
"""Parses command-line arguments for Optuna Ensemble HPO."""
parser = argparse.ArgumentParser(
description="Run HPO optimizing ensemble performance using Optuna.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'-c', '--config', type=str, default='forecasting_config.yaml',
help="Path to the YAML configuration file."
)
parser.add_argument(
'--output-dir', type=str, default=None,
help="Override base output directory for HPO results."
)
parser.add_argument(
'--keep-artifacts', action='store_true',
help="Prevent cleanup of trial directories after best trial is determined."
)
args = parser.parse_args()
return args
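# Example invocation (paths are illustrative):
#   python optuna_ensemble_run.py --config forecasting_config.yaml --output-dir output/hpo_ensemble
#   python optuna_ensemble_run.py -c forecasting_config.yaml --keep-artifacts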
# --- Optuna Objective Function ---
def objective(
trial: optuna.Trial,
base_config: MainConfig,
df: pd.DataFrame,
ensemble_hpo_output_dir: Path # Base output directory for this ensemble HPO run
) -> float: # Return the single ensemble metric to optimize
"""
Optuna objective function optimizing ensemble performance.
"""
logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
trial_start_time = time.perf_counter()
# Define trial-specific output directory for fold artifacts
trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")
hpo_config = base_config.optuna
# Metric for pruning based on individual fold performance
validation_metric_monitor = hpo_config.metric_to_optimize
# Ensemble metric and method to optimize (e.g., MAE of the 'mean' ensemble)
ensemble_metric_optimize = 'MAE'
ensemble_method_optimize = 'mean'
optimization_direction = hpo_config.direction # 'minimize' or 'maximize'
worst_value = float('inf') if optimization_direction == 'minimize' else float('-inf')
# Store paths and details for all saved artifacts for this trial's folds
fold_artifact_details: List[Dict[str, Any]] = []
# --- 1. Suggest Hyperparameters ---
try:
trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
except Exception as e:
logger.error(f"Trial {trial.number}: Failed to deep copy base config: {e}", exc_info=True)
return worst_value
# ----- Suggest Hyperparameters -----
# Modify trial_config_dict using trial.suggest_*
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256, 512])
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 3)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 3, 72, step=2)
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Suggestions -----
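# Illustrative shape of trial.params after the suggestions above (values are hypothetical):
#   {'learning_rate': 3.2e-4, 'batch_size': 64, 'loss_function': 'MAE', 'hidden_size': 146,
#    'num_layers': 2, 'dropout': 0.15, 'sequence_length': 25, 'scaling_method': 'standard',
#    'use_lags': True, 'use_rolling_windows': False, 'use_time_features': True,
#    'sinus_curve': False, 'cosine_curve': True, 'fill_nan': 'ffill'}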
# --- 2. Re-validate Trial Config ---
try:
trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Disable plotting during HPO runs to save time/resources
trial_config_dict['evaluation']['save_plots'] = False
trial_config = MainConfig(**trial_config_dict)
logger.info(f"Trial {trial.number} Parameters: {trial.params}")
except Exception as e:
logger.error(f"Trial {trial.number}: Invalid config generated: {e}", exc_info=True)
return worst_value
# --- Early check for data length ---
# Skip this trial early if sequence_length + max(forecast_horizon) cannot fit in the data.
try:
if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
raise ValueError("Trial config has invalid forecast_horizon list.")
min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
if min_data_for_sequence > len(df):
logger.warning(f"Trial {trial.number}: Skipped. sequence_length + max_horizon ({min_data_for_sequence}) exceeds data length ({len(df)}).")
# Report worst value so Optuna knows this trial failed badly
# Using study direction to determine the appropriate "worst" value
return worst_value
# raise optuna.TrialPruned() # Alternative: Prune instead of returning worst value
except Exception as e:
logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
return worst_value
# --- 3. Run Cross-Validation Training (Saving Artifacts) ---
all_fold_best_val_scores = {} # Store best val scores for pruning
actual_folds_trained = 0
try:
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
fold_id = fold_num + 1
logger.info(f"Trial {trial.number}, Fold {fold_id}/{cv_splitter.n_splits}: Training model...")
current_fold_best_metric = None # Reset for each fold
try:
# Use run_single_fold - it handles training and saving artifacts
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
config=trial_config,
full_df=df,
output_base_dir=trial_artifacts_dir,
enable_progress_bar=False
)
actual_folds_trained += 1
all_fold_best_val_scores[fold_id] = best_val_score
# Store all artifact paths for this fold
fold_artifact_details.append({
"fold_id": fold_id,
"model_path": str(saved_model_path) if saved_model_path else None,
"target_scaler_path": str(saved_target_scaler_path) if saved_target_scaler_path else None,
"data_scaler_path": str(saved_data_scaler_path) if saved_data_scaler_path else None,
"input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
"config_path": str(saved_config_path) if saved_config_path else None,
})
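# Illustrative entry (file names are hypothetical; run_single_fold decides the actual layout):
#   {"fold_id": 1, "model_path": "<trial_dir>/fold_01/model.ckpt",
#    "target_scaler_path": "<trial_dir>/fold_01/target_scaler.pkl", ...}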
# Check if the monitored validation metric was returned
if best_val_score is not None and np.isfinite(best_val_score):
current_fold_best_metric = best_val_score
logger.info(f"Trial {trial.number}, Fold {fold_id}: Best val score ({validation_metric_monitor}) = {current_fold_best_metric:.4f}")
else:
# Use worst value if metric is missing/invalid for pruning
logger.warning(f"Trial {trial.number}, Fold {fold_id}: Invalid or missing validation score ({validation_metric_monitor}). Using {worst_value} for pruning.")
current_fold_best_metric = worst_value # Assign worst for pruning report
# Report intermediate value (individual fold validation score) for pruning
trial.report(current_fold_best_metric, fold_num)
if trial.should_prune():
logger.info(f"Trial {trial.number}: Pruned after fold {fold_id}.")
raise optuna.TrialPruned()
except optuna.TrialPruned:
raise # Propagate prune signal
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
all_fold_best_val_scores[fold_id] = None # Mark fold as failed
trial.report(worst_value, fold_num)
except optuna.TrialPruned:
logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
return worst_value # Return worst value when pruned
except Exception as e:
logger.critical(f"Trial {trial.number}: Failed critically during CV training setup/loop: {e}", exc_info=True)
return worst_value
# --- 4. Run Ensemble Evaluation ---
if actual_folds_trained < 2:
logger.error(f"Trial {trial.number}: Only {actual_folds_trained} folds trained successfully. Cannot run ensemble evaluation.")
return worst_value # Not enough models for ensemble
logger.info(f"Trial {trial.number}: Starting Ensemble Evaluation using {actual_folds_trained} trained models...")
ensemble_metric_final = worst_value # Initialize to worst
best_ensemble_method_for_trial = None # Track the best method for this trial
ensemble_results = None # Initialize so the user_attr write below cannot raise NameError if evaluation fails
try:
# Run evaluation using the artifacts saved in the trial's output directory
ensemble_results = run_ensemble_evaluation(
config=trial_config,
output_base_dir=trial_artifacts_dir # Pass the specific trial's artifact dir
)
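# Expected shape, inferred from the lookups below: {fold_num: {method: {metric_name: value}}},
# e.g. {1: {"mean": {"MAE": 10.2, ...}, "<other_method>": {...}}, 2: {...}}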
if ensemble_results:
# Aggregate the results to get the final objective value
ensemble_metrics_for_method = []
for fold_num, fold_res in ensemble_results.items():
if fold_res and ensemble_method_optimize in fold_res:
method_metrics = fold_res[ensemble_method_optimize]
if method_metrics and ensemble_metric_optimize in method_metrics:
metric_val = method_metrics[ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
ensemble_metrics_for_method.append(metric_val)
else:
logger.warning(f"Trial {trial.number}: Invalid ensemble metric value found for fold {fold_num}, method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
else:
logger.warning(f"Trial {trial.number}: Metric '{ensemble_metric_optimize}' not found for method '{ensemble_method_optimize}' in fold {fold_num}.")
else:
logger.warning(f"Trial {trial.number}: Ensemble method '{ensemble_method_optimize}' results not found for fold {fold_num}.")
if not ensemble_metrics_for_method:
logger.error(f"Trial {trial.number}: No valid ensemble metrics found for method '{ensemble_method_optimize}', metric '{ensemble_metric_optimize}'.")
ensemble_metric_final = worst_value
else:
# Calculate the mean of the chosen ensemble metric across test folds
ensemble_metric_final = np.mean(ensemble_metrics_for_method)
logger.info(f"Trial {trial.number}: Final Ensemble Metric (Avg {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
# Determine the best ensemble method based on average performance across folds
# This requires re-calculating averages for *all* methods evaluated by run_ensemble_evaluation
all_ensemble_methods = set()
for fold_res in ensemble_results.values():
if fold_res: all_ensemble_methods.update(fold_res.keys())
avg_metrics_per_method = {}
for method in all_ensemble_methods:
method_metrics_across_folds = []
for fold_res in ensemble_results.values():
if fold_res and method in fold_res and ensemble_metric_optimize in fold_res[method]:
metric_val = fold_res[method][ensemble_metric_optimize]
if metric_val is not None and np.isfinite(metric_val):
method_metrics_across_folds.append(metric_val)
if method_metrics_across_folds:
avg_metrics_per_method[method] = np.mean(method_metrics_across_folds)
if avg_metrics_per_method:
if optimization_direction == 'minimize':
best_ensemble_method_for_trial = min(avg_metrics_per_method, key=avg_metrics_per_method.get)
else: # maximize
best_ensemble_method_for_trial = max(avg_metrics_per_method, key=avg_metrics_per_method.get)
logger.info(f"Trial {trial.number}: Best performing ensemble method for this trial (based on avg {ensemble_metric_optimize}): {best_ensemble_method_for_trial}")
else:
logger.warning(f"Trial {trial.number}: Could not determine best ensemble method based on average {ensemble_metric_optimize}.")
else:
logger.error(f"Trial {trial.number}: Ensemble evaluation function returned no results.")
ensemble_metric_final = worst_value
except Exception as e:
logger.error(f"Trial {trial.number}: Failed during ensemble evaluation phase: {e}", exc_info=True)
ensemble_metric_final = worst_value
# --- 5. Return Final Objective Value ---
trial_duration = time.perf_counter() - trial_start_time
logger.info(f"--- Trial {trial.number}: Finished ---")
logger.info(f" Final Objective Value (Avg Ensemble {ensemble_method_optimize} {ensemble_metric_optimize}): {ensemble_metric_final:.6f}")
logger.info(f" Total time: {trial_duration:.2f}s")
# Store ensemble evaluation results and best method in trial user attributes
# This makes it easier to retrieve the best ensemble details after the study
trial.set_user_attr("ensemble_evaluation_results", ensemble_results)
trial.set_user_attr("best_ensemble_method", best_ensemble_method_for_trial)
# trial.set_user_attr("fold_model_paths", fold_model_paths) # Removed
# trial.set_user_attr("fold_scaler_paths", fold_scaler_paths) # Removed
trial.set_user_attr("fold_artifact_details", fold_artifact_details) # Added comprehensive artifact details
return ensemble_metric_final
# --- Main HPO Execution ---
def run_hpo():
"""Main execution function for HPO optimizing ensemble performance."""
args = parse_arguments()
config_path = Path(args.config)
try:
base_config = load_config(config_path, MainConfig)
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)
# --- Setup Output Dir ---
# 1. Determine the main output directory
if args.output_dir:
# Command-line argument overrides config
main_output_dir = Path(args.output_dir)
logger.info(f"Using main output directory from command line: {main_output_dir}")
elif hasattr(base_config, 'output_dir') and base_config.output_dir:
main_output_dir = Path(base_config.output_dir)
logger.info(f"Using main output directory from config file: {main_output_dir}")
else:
main_output_dir = Path("output") # Default if not specified anywhere
logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")
# 2. Define the specific directory for this ensemble HPO run
ensemble_hpo_output_dir = main_output_dir / "ensemble"
# 3. Create directories
main_output_dir.mkdir(parents=True, exist_ok=True)
ensemble_hpo_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Ensemble HPO outputs will be saved under: {ensemble_hpo_output_dir}")
# --- Setup Logging ---
try:
level_name = base_config.log_level.upper()
# logging.getLevelName() maps a level-name string (e.g. "INFO") to its numeric value
effective_log_level = logging.getLevelName(level_name)
# Ensure study name is filesystem-safe if used directly
safe_study_name = base_config.optuna.study_name
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in safe_study_name)
# Place log file directly inside the ensemble HPO directory
log_file = ensemble_hpo_output_dir / f"{safe_study_name}_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') # Specify encoding
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# Prevent adding duplicate handlers if script/function is called multiple times
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.setLevel(effective_log_level)
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
except (AttributeError, ValueError, TypeError) as e:
logger.warning(f"Could not set log level from config: {e}. Defaulting to INFO.")
logger.setLevel(logging.INFO)
safe_study_name = "default_study" # Fallback so later file names that use the study name still resolve
# Still try to log to a default file if possible
try:
# Default log file also goes into the specific ensemble directory
log_file = ensemble_hpo_output_dir / "default_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.info(f"Logging to default file: {log_file}")
except Exception as log_e:
logger.error(f"Failed to set up default file logging: {log_e}")
# --- Setup Seeding ---
set_seeds(getattr(base_config, 'random_seed', 42))
# --- Load Data ---
try:
logger.info("Loading base dataset for HPO...")
df = load_raw_data(base_config.data)
logger.info(f"Base dataset loaded. Shape: {df.shape}")
except FileNotFoundError as e:
logger.critical(f"Data file not found: {e}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
sys.exit(1)
# --- Optuna Study Setup ---
try:
hpo_config = base_config.optuna
if not hpo_config.enabled:
logger.info("Optuna optimization is disabled in the configuration.")
sys.exit(0)
except AttributeError:
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)
storage_string = hpo_config.storage
storage_path = None # Initialize
if storage_string and storage_string.startswith("sqlite:///"):
db_filename = storage_string.replace("sqlite:///", "").strip()
if not db_filename:
# Use study name if filename is empty
db_filename = f"{safe_study_name}_ensemble.db"
logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
# Place the DB file inside the ensemble HPO directory
db_path = ensemble_hpo_output_dir / db_filename
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using SQLite storage: {storage_path}")
elif storage_string:
# Assume it's a non-SQLite connection string or a pre-configured path
storage_path = storage_string
logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: DB file will not be placed inside {ensemble_hpo_output_dir}")
else:
storage_path = None # Explicitly set to None for in-memory
logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")
try:
# Single objective study based on ensemble performance
study = optuna.create_study(
study_name=hpo_config.study_name,
storage=storage_path,
direction=hpo_config.direction, # 'minimize' or 'maximize'
load_if_exists=True,
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
)
# --- Run Optimization ---
logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df, ensemble_hpo_output_dir), # Pass ensemble output dir
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True # Garbage collect after trial
)
# --- Report and Save Best Trial ---
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")
# Identify the best completed trial; Optuna excludes pruned/failed trials automatically
try:
best_trial = study.best_trial
except ValueError: # Optuna raises ValueError if no trials completed successfully
best_trial = None
logger.warning("No successful trials completed. Cannot determine best trial.")
if best_trial:
logger.info("--- Best Trial ---")
logger.info(f" Trial Number: {best_trial.number}")
# Ensure value is not None before formatting
best_value_str = f"{best_trial.value:.6f}" if best_trial.value is not None else "N/A"
logger.info(f" Objective Value (Ensemble Metric): {best_value_str}")
logger.info(f" Hyperparameters:")
best_params = best_trial.params
for key, value in best_params.items():
logger.info(f" {key}: {value}")
# Save best hyperparameters directly into the ensemble output dir
best_params_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_params.json"
try:
with open(best_params_file, 'w', encoding='utf-8') as f:
json.dump(best_params, f, indent=4)
logger.info(f"Best hyperparameters saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save best parameters: {e}", exc_info=True)
# Save the corresponding config directly into the ensemble output dir
best_config_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
try:
# Use a fresh deepcopy to avoid modifying the original base_config
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
# Update with best trial's hyperparameters
# Pitfall: This assumes keys match exactly and exist in these sections.
# A more robust approach might involve checking key existence or
# iterating through the config structure if params are nested differently.
for key, value in best_params.items():
if key in best_config_dict.get('training', {}): best_config_dict['training'][key] = value
elif key in best_config_dict.get('model', {}): best_config_dict['model'][key] = value
elif key in best_config_dict.get('features', {}): best_config_dict['features'][key] = value
elif key in ["use_lags", "use_rolling_windows"]:
# If False, the corresponding feature list was already cleared to [] during parameter suggestion.
pass
else:
logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")
# Ensure forecast horizon is preserved from the original config
best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
# Maybe remove optuna section from the best config?
# best_config_dict.pop('optuna', None)
with open(best_config_file, 'w', encoding='utf-8') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
logger.info(f"Configuration for best trial saved to {best_config_file}")
except Exception as e:
logger.error(f"Failed to save best configuration: {e}", exc_info=True)
# --- Retrieve Artifacts and Save Ensemble Definition ---
# Base directory for this trial's artifacts
best_trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", [])
# Check if artifacts exist and data is available
if not best_trial_artifacts_dir.exists():
logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
elif not best_ensemble_method:
logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
elif not fold_artifact_details: # Check if any artifact details were recorded
logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
else:
# --- Save Best Ensemble Definition ---
logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")
# Save definition file directly into the ensemble output dir
ensemble_definition_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_ensemble.json"
best_ensemble_definition = {
"trial_number": best_trial.number,
"objective_value": best_trial.value,
"hyperparameters": best_trial.params,
"ensemble_method": best_ensemble_method,
# The base dir for artifacts, relative to the main ensemble output dir
"ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(ensemble_hpo_output_dir)), # Corrected path
"fold_models": [],
}
# Populate fold_models with paths relative to best_trial_artifacts_dir
for artifact_detail in fold_artifact_details:
fold_def = {
"fold_id": artifact_detail.get("fold_id"),
"model_path": None,
"target_scaler_path": None,
"data_scaler_path": None, # Added placeholder
"input_size_path": None,
"config_path": None,
}
# Process each artifact path, storing it relative to the trial artifacts dir where possible
for key in ["model_path", "target_scaler_path", "data_scaler_path", "input_size_path", "config_path"]:
abs_path_str = artifact_detail.get(key)
if abs_path_str:
abs_path = Path(abs_path_str).absolute()
try:
# Make path relative to the trial artifacts dir (where models/scalers reside)
relative_path = str(abs_path.relative_to(best_trial_artifacts_dir.absolute()))
fold_def[key] = relative_path
except ValueError:
# This shouldn't happen if paths were saved correctly, but handle just in case
logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
fold_def[key] = str(abs_path) # Fallback to absolute path
best_ensemble_definition["fold_models"].append(fold_def)
try:
with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
json.dump(best_ensemble_definition, f, indent=4)
logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
except Exception as e:
logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)
# --- Optional: Clean up artifact directories for non-best trials ---
if not args.keep_artifacts:
logger.info("Cleaning up artifact directories for non-best trials...")
# The base path for all trial artifacts within the ensemble dir
ensemble_artifacts_base_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts"
if ensemble_artifacts_base_dir.exists():
for item in ensemble_artifacts_base_dir.iterdir():
if item.is_dir():
# Check if this directory belongs to the best trial
if best_trial and item.name == f"trial_{best_trial.number}":
logger.debug(f"Keeping artifact directory for best trial: {item}")
continue
else:
logger.debug(f"Removing artifact directory for non-best trial: {item}")
try:
shutil.rmtree(item)
except Exception as e:
logger.error(f"Failed to remove directory {item}: {e}", exc_info=True)
else:
logger.debug(f"Artifacts base directory not found for cleanup: {ensemble_artifacts_base_dir}")
except optuna.exceptions.StorageInternalError as e:
logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
sys.exit(1)
except Exception as e:
logger.critical(f"An critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
run_hpo()