# entrix_case_challange/optuna_classic_run.py
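"""
Optuna hyperparameter optimization (HPO) entry point for the classic
(single train/val/test split) forecasting pipeline.

Each trial trains an LSTMForecastLightningModule on a classic split and is
scored by its best validation metric. After the study finishes, the best
trial is re-run with artifact saving enabled, and its hyperparameters and
resolved configuration are written to the HPO output directory.
"""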
import argparse
import json
import logging
import sys
import warnings
import copy
from pathlib import Path
import time
import pandas as pd
import torch
import yaml
import optuna
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.utils.data_processing import (
    prepare_fold_data_and_loaders,
    split_data_classic,
)
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.train.classic import run_model_training
# Import helper functions from forecasting_model_run.py
from forecasting_model.utils.helper import load_config, set_seeds
# Import the data processing functions
from forecasting_model import load_raw_data
# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in this HPO script, where logger=False is intentional
warnings.filterwarnings(
    "ignore",
    message=".*You called `self.log.*logger=True.*no logger configured.*",
    category=UserWarning,
    module="pytorch_lightning.core.module"
)
# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)
# --- Basic Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)-7s - %(message)s',
    datefmt='%H:%M:%S'
)
# Get the root logger
logger = logging.getLogger()
# --- Argument Parsing (Simplified) ---
def parse_arguments():
    """Parses command-line arguments for Optuna HPO."""
    parser = argparse.ArgumentParser(
        description="Run Hyperparameter Optimization using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config', type=str, default='forecasting_config.yaml',
        help="Path to the YAML configuration file containing HPO settings."
    )
    parser.add_argument(
        '--output-dir', type=str, default=None,
        help="Override the output directory specified in the configuration file."
    )
    args = parser.parse_args()
    return args
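# Example invocation (paths are illustrative):
#   python optuna_classic_run.py --config forecasting_config.yaml --output-dir output/hpo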
# --- Optuna Objective Function ---
def objective(
    trial: optuna.Trial,
    base_config: MainConfig,
    df: pd.DataFrame,
) -> float:
    """
    Optuna single-objective function using a classic train/val/test split.

    Returns:
        The best validation score achieved on the classic split (minimized).
    """
    logger.info(f"\n--- Starting Optuna Trial {trial.number} ---")
    trial_start_time = time.perf_counter()

    # Get HPO settings
    hpo_config = base_config.optuna
    # Use the specific metric name from the config for validation monitoring
    validation_metric_monitor = hpo_config.metric_to_optimize
    # The objective is the validation metric; minimization is hardcoded here
    monitor_mode = "min"
    worst_value = float('inf')  # Sentinel returned for failed or infeasible trials

    # --- 1. Suggest Hyperparameters ---
    try:
        # Deep copy the base config dict (mode='python' keeps values mutable)
        # so trial-specific mutations don't leak into other trials
        trial_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed to deep copy base configuration: {e}", exc_info=True)
        return worst_value
    # ----- Suggest Hyperparameters -----
    trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
    trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
    trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
    trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
    trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
    trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
    # Note: forecast_horizon is NOT tuned here; it is taken from base_config
    trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
    use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
    if not use_configured_lags:
        trial_config_dict['features']['lags'] = []
    use_configured_rolling = trial.suggest_categorical('use_rolling_windows', [True, False])
    if not use_configured_rolling:
        trial_config_dict['features']['rolling_window_sizes'] = []
    trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
    trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
    trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
    trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
    # ----- End of Hyperparameter Suggestions -----
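    # The search space above spans three groups: optimizer settings (learning
    # rate, batch size, loss), model architecture (hidden size, layers,
    # dropout), and feature engineering (sequence length, scaling, lags,
    # rolling windows, time/sinusoidal features, NaN filling).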
    # --- 2. Re-validate Trial Config ---
    try:
        # Transfer the forecast horizon from the base config (not tuned)
        trial_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
        trial_config = MainConfig(**trial_config_dict)
        logger.info(f"Trial {trial.number} Parameters: {trial.params}")
    except Exception as e:
        logger.error(f"Trial {trial.number}: Invalid configuration generated: {e}")
        return worst_value

    # --- Early check for an infeasible sequence_length / forecast_horizon combination ---
    try:
        # Make sure forecast_horizon is a non-empty list before calling max()
        if not isinstance(trial_config.features.forecast_horizon, list) or not trial_config.features.forecast_horizon:
            raise ValueError("Trial config has an invalid forecast_horizon list.")
        min_data_for_sequence = trial_config.features.sequence_length + max(trial_config.features.forecast_horizon)
        if min_data_for_sequence > len(df):
            logger.warning(
                f"Trial {trial.number}: Skipped. sequence_length ({trial_config.features.sequence_length}) + "
                f"max_horizon ({max(trial_config.features.forecast_horizon)}) "
                f"exceeds data length ({len(df)})."
            )
            # Optuna has no native "skip", so return the worst value instead
            return worst_value
    except Exception as e:
        logger.error(f"Trial {trial.number}: Error during pre-check: {e}", exc_info=True)
        return worst_value
    # --- 3. Run Classic Train/Test ---
    logger.info(f"Trial {trial.number}: Starting Classic Run...")
    # Pre-bind the names used in the finally-block cleanup so `del` is safe
    # even if the run fails before they are assigned
    model_cl = trainer_cl = None
    train_loader_cl = val_loader_cl = test_loader_cl = None
    try:
        n_samples = len(df)
        val_frac = trial_config.cross_validation.val_size_fraction
        test_frac = trial_config.cross_validation.test_size_fraction
        train_idx_cl, val_idx_cl, test_idx_cl = split_data_classic(n_samples, val_frac, test_frac)

        # Prepare data for the classic split
        train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, data_scaler, input_size_cl = prepare_fold_data_and_loaders(
            full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
            target_col=trial_config.data.target_col, feature_config=trial_config.features,
            train_config=trial_config.training, eval_config=trial_config.evaluation
        )

        # Initialize the model
        model_cl = LSTMForecastLightningModule(
            model_config=trial_config.model, train_config=trial_config.training,
            input_size=input_size_cl, target_scaler=target_scaler_cl, data_scaler=data_scaler
        )

        # Callbacks (early stopping on the validation metric)
        callbacks_cl = []
        if trial_config.training.early_stopping_patience and trial_config.training.early_stopping_patience > 0:
            callbacks_cl.append(EarlyStopping(
                monitor=validation_metric_monitor,
                patience=trial_config.training.early_stopping_patience,
                mode=monitor_mode, verbose=False
            ))

        # Trainer for the classic run
        trainer_cl = pl.Trainer(
            check_val_every_n_epoch=trial_config.training.check_val_n_epoch,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1 if torch.cuda.is_available() else None,
            max_epochs=trial_config.training.epochs,
            callbacks=callbacks_cl,
            logger=False,  # no per-trial logger; see the warning filter at the top
            enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
            gradient_clip_val=getattr(trial_config.training, 'gradient_clip_val', None),
            precision=getattr(trial_config.training, 'precision', 32),
        )
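        # Checkpointing, logging, progress bars, and model summaries are all
        # disabled to keep each trial lightweight; full artifacts are produced
        # only when the best trial is re-run in run_hpo() below.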
        # --- Train Model ---
        logger.info(f"Trial {trial.number}: Fitting model on classic train/val split...")
        trainer_cl.fit(model_cl, train_dataloaders=train_loader_cl, val_dataloaders=val_loader_cl)

        # --- Get Best Validation Score ---
        # Prefer the EarlyStopping callback's best_score if available
        best_score_tensor = None
        if callbacks_cl and isinstance(callbacks_cl[0], EarlyStopping):
            if hasattr(callbacks_cl[0], 'best_score') and callbacks_cl[0].best_score is not None:
                best_score_tensor = callbacks_cl[0].best_score
            elif callbacks_cl[0].stopped_epoch > 0:  # Early stopping triggered
                logger.debug(f"Trial {trial.number}: Early stopping triggered, attempting to use last callback metric.")

        # If early stopping didn't capture a best score, fall back to the last
        # logged value in the trainer's callback metrics
        if best_score_tensor is None:
            metric_val = trainer_cl.callback_metrics.get(validation_metric_monitor)
            if metric_val is not None:
                best_score_tensor = metric_val

        if best_score_tensor is None:
            logger.warning(f"Trial {trial.number}: Metric '{validation_metric_monitor}' not found in callbacks or metrics. Using {worst_value}.")
            validation_metric_value = worst_value
        else:
            validation_metric_value = best_score_tensor.item()
            logger.info(f"Trial {trial.number}: Best val score ({validation_metric_monitor}) = {validation_metric_value:.4f}")

        # Report the value for pruning (if enabled)
        trial.report(validation_metric_value, trainer_cl.current_epoch)
        if trial.should_prune():
            logger.info(f"Trial {trial.number}: Pruned.")
            raise optuna.TrialPruned()
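        # Note: the metric is reported to Optuna only once, after fit()
        # completes, so the MedianPruner effectively compares finished trials
        # rather than pruning mid-training. Per-epoch pruning would require
        # reporting from inside training, e.g. via Optuna's
        # PyTorchLightningPruningCallback integration.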
        # Note: the test set is not predicted/evaluated here, since the
        # objective uses validation performance only; test evaluation is done
        # later for the best trial.
        logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")

    except optuna.TrialPruned:
        # Propagate the prune signal; Optuna records the trial as pruned
        raise
    except Exception as e:
        logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)
        validation_metric_value = worst_value  # Assign the worst value if the classic run fails
    finally:
        # Clean up GPU memory after the run (names were pre-bound to None
        # above, so `del` is safe even after an early failure)
        del model_cl, trainer_cl, train_loader_cl, val_loader_cl, test_loader_cl
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # --- 4. Return Objective ---
    logger.info(f"--- Trial {trial.number}: Finished ---")
    logger.info(f"  Objective (Validation {validation_metric_monitor}): {validation_metric_value:.5f}")
    logger.info(f"  Total time: {time.perf_counter() - trial_start_time:.2f}s")
    return float(validation_metric_value)
# --- Main HPO Execution ---
def run_hpo():
    """Main execution function for HPO."""
    args = parse_arguments()
    config_path = Path(args.config)
    try:
        base_config = load_config(config_path, MainConfig)  # Load the base config once
        logger.info(f"Successfully loaded configuration from {config_path}")
    except Exception as e:
        logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
        sys.exit(1)

    # --- Setup Output Dir ---
    # 1. Determine the main output directory
    if args.output_dir:
        main_output_dir = Path(args.output_dir)
        logger.info(f"Using main output directory from command line: {main_output_dir}")
    elif hasattr(base_config, 'output_dir') and base_config.output_dir:
        main_output_dir = Path(base_config.output_dir)
        logger.info(f"Using main output directory from config file: {main_output_dir}")
    else:
        main_output_dir = Path("output")  # Default
        logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")

    # 2. Define the specific directory for this classic HPO run
    classic_hpo_output_dir = main_output_dir / "classic"

    # 3. Create the directories
    main_output_dir.mkdir(parents=True, exist_ok=True)
    classic_hpo_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Classic HPO outputs will be saved under: {classic_hpo_output_dir}")
    # Ensure the study name is filesystem-safe. It is computed before the
    # logging try-block because it is also needed in the fallback path and for
    # the DB and result filenames further below.
    safe_study_name = "".join(
        c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name
    )

    # --- Setup Logging ---
    try:
        level_name = base_config.log_level.upper()
        effective_log_level = logging.getLevelName(level_name)
        # Place the log file directly inside the classic HPO directory
        log_file = classic_hpo_output_dir / f"{safe_study_name}_hpo.log"
        file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        file_handler.setFormatter(formatter)
        # Add the handler only if an equivalent one is not already attached
        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
            logger.addHandler(file_handler)
        logger.setLevel(effective_log_level)
        logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
        if effective_log_level <= logging.DEBUG:
            logger.debug("Debug logging enabled.")
    except (AttributeError, ValueError, TypeError) as e:
        logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
        logger.setLevel(logging.INFO)
        # Still try to log to a default file if possible
        try:
            # The default log file also goes into the classic directory
            log_file = classic_hpo_output_dir / "default_classic_hpo.log"
            file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
            formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
            file_handler.setFormatter(formatter)
            if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
                logger.addHandler(file_handler)
            logger.info(f"Logging to default file: {log_file}")
        except Exception as log_e:
            logger.error(f"Failed to set up default file logging: {log_e}")

    # --- Seeding ---
    try:
        seed = base_config.random_seed
        set_seeds(seed)
    except AttributeError:
        logger.warning("Config missing 'random_seed'. Using default seed 42.")
        set_seeds(42)
    # --- Load Data ---
    try:
        logger.info("Loading base dataset for HPO...")
        df = load_raw_data(base_config.data)
        logger.info(f"Base dataset loaded. Shape: {df.shape}")
    except Exception as e:
        logger.critical(f"Failed to load raw data for HPO: {e}", exc_info=True)
        sys.exit(1)

    # --- Optuna Study Setup ---
    try:
        hpo_config = base_config.optuna
        if not hpo_config.enabled:
            logger.info("Optuna optimization is disabled in the configuration.")
            sys.exit(0)
    except AttributeError:
        logger.critical("Optuna configuration section ('optuna') missing.")
        sys.exit(1)

    # Resolve the Optuna storage backend
    storage_string = hpo_config.storage
    storage_path = None
    if storage_string and storage_string.startswith("sqlite:///"):
        db_filename = storage_string.replace("sqlite:///", "").strip()
        if not db_filename:
            # Fall back to the study name if the filename is empty
            db_filename = f"{safe_study_name}.db"
            logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
        # Place the DB file inside the classic HPO directory
        db_path = classic_hpo_output_dir / db_filename
        storage_path = f"sqlite:///{db_path.resolve()}"
        logger.info(f"Using SQLite storage: {storage_path}")
    elif storage_string:
        # Assume a non-SQLite connection string or a pre-configured path
        storage_path = storage_string
        logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: the DB file will not be placed inside {classic_hpo_output_dir}")
    else:
        storage_path = None  # Explicitly None for in-memory storage
        logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")
    try:
        # Single-objective study; the direction comes from the config
        study = optuna.create_study(
            study_name=hpo_config.study_name,
            storage=storage_path,
            direction=hpo_config.direction,
            load_if_exists=True,
            pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
        )

        # --- Run Optimization ---
        logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
        study.optimize(
            lambda trial: objective(trial, base_config, df),
            n_trials=hpo_config.n_trials,
            timeout=None,
            gc_after_trial=True
        )
        # --- Report and Process Best Trial ---
        logger.info("--- Optuna HPO Finished ---")
        logger.info(f"Number of finished trials: {len(study.trials)}")

        best_trial = None
        try:
            best_trial = study.best_trial
        except ValueError:
            logger.warning("Optuna study finished, but no successful trial was completed.")

        if best_trial is not None:
            logger.info(f"Best trial found (Trial {best_trial.number}):")
            # Log details for the best trial
            validation_metric_monitor = base_config.optuna.metric_to_optimize
            logger.info(f"  Objective ({validation_metric_monitor}): {best_trial.value:.5f}")  # .value for single objective
            logger.info("  Hyperparameters:")
            for key, value in best_trial.params.items():
                logger.info(f"    {key}: {value}")

            # --- Re-run and Save Artifacts for the Best Trial ---
            logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
            best_trial_output_dir = classic_hpo_output_dir / f"best_trial_num{best_trial.number}"
            best_trial_output_dir.mkdir(parents=True, exist_ok=True)
            try:
                # 1. Build the config for the best trial
                best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
                # Place each tuned parameter in its nested config section
                for key, value in best_trial.params.items():
                    if key in best_config_dict['training']:
                        best_config_dict['training'][key] = value
                    elif key in best_config_dict['model']:
                        best_config_dict['model'][key] = value
                    elif key in best_config_dict['features']:
                        best_config_dict['features'][key] = value
                    # Add more sections here if HPO tunes them

                # The boolean indicator params gate the configured lists; mirror
                # the logic in objective(), since these keys are not config keys
                # and would otherwise be silently dropped by the loop above
                if best_trial.params.get('use_lags') is False:
                    best_config_dict['features']['lags'] = []
                if best_trial.params.get('use_rolling_windows') is False:
                    best_config_dict['features']['rolling_window_sizes'] = []

                # Restore the non-tuned forecast_horizon
                best_config_dict['features']['forecast_horizon'] = base_config.features.forecast_horizon
                # Ensure evaluation plots are enabled for this final run
                best_config_dict['evaluation']['save_plots'] = True
                # (If model saving is configurable, it could be forced on here as well)
                best_trial_config = MainConfig(**best_config_dict)

                # Save the exact config used for this best run inside its directory
                with open(best_trial_output_dir / "best_run_config.yaml", 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

                # 2. Run classic training, saving outputs under best_trial_output_dir
                logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
                run_model_training(
                    config=best_trial_config,
                    full_df=df,
                    output_base_dir=best_trial_output_dir
                )
                logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {best_trial_output_dir}")
            except Exception as e:
                logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)

            # --- Save Best Hyperparameters and Config at the Top Level ---
            best_params_file = classic_hpo_output_dir / f"{safe_study_name}_best_params.json"
            try:
                with open(best_params_file, 'w', encoding='utf-8') as f:
                    json.dump(best_trial.params, f, indent=4)
                logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
            except Exception as e:
                logger.error(f"Failed to save parameters for best trial: {e}", exc_info=True)

            best_config_file = classic_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
            try:
                # best_config_dict still holds the config from the re-run step above
                with open(best_config_file, 'w', encoding='utf-8') as f:
                    yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
                logger.info(f"Configuration for best trial saved to {best_config_file}")
            except Exception as e:
                logger.error(f"Failed to save best configuration: {e}", exc_info=True)
    except optuna.exceptions.StorageInternalError as e:
        logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    run_hpo()
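# Expected on-disk layout after a successful run (names illustrative; the DB
# filename depends on the configured storage string):
#   <output_dir>/classic/
#     <study>_hpo.log
#     <study>.db
#     <study>_best_params.json
#     <study>_best_config.yaml
#     best_trial_num<N>/
#       best_run_config.yaml
#       ...artifacts written by run_model_training...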