import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
import yaml
from sklearn.base import BaseEstimator, TransformerMixin  # For scaler type checks

# Import the required components from forecasting_model.
from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule

logger = logging.getLogger(__name__)


def load_single_model_artifact(
    model_path: Path,
    config_path: Path,
    input_size_path: Path,
    target_scaler_path: Optional[Path] = None,
) -> Optional[Dict[str, Any]]:
    """
    Load the artifacts belonging to a single trained model checkpoint.

    Args:
        model_path: Path to the model checkpoint file (.ckpt).
        config_path: Path to the corresponding main YAML config file.
        input_size_path: Path to the input_size.pt file.
        target_scaler_path: Optional path to the target_scaler.pt file.

    Returns:
        A dictionary with the loaded artifacts ('model_instance',
        'feature_config', 'target_scaler', 'main_forecasting_config'),
        or None if loading fails.
    """
logger.info(f"Loading single model artifact from directory: {model_path.parent}")
|
|
loaded_artifacts = {}
|
|
|
|
try:
|
|
# 1. Load Config
|
|
if not config_path.is_file():
|
|
logger.error(f"Config file not found at {config_path}")
|
|
return None
|
|
with open(config_path, 'r') as f:
|
|
config_data = yaml.safe_load(f)
|
|
main_config = MainConfig(**config_data)
|
|
loaded_artifacts['main_forecasting_config'] = main_config
|
|
loaded_artifacts['feature_config'] = main_config.features
|
|
logger.debug(f"Loaded config from {config_path}")
|
|
|
|
# 2. Load Input Size
|
|
if not input_size_path.is_file():
|
|
logger.error(f"Input size file not found at {input_size_path}")
|
|
return None
|
|
input_size = torch.load(input_size_path)
|
|
if not isinstance(input_size, int) or input_size <= 0:
|
|
logger.error(f"Invalid input size loaded from {input_size_path}: {input_size}")
|
|
return None
|
|
logger.debug(f"Loaded input size ({input_size}) from {input_size_path}")
|
|
|
|
# 3. Load Target Scaler (Optional)
|
|
target_scaler = None
|
|
if target_scaler_path:
|
|
if not target_scaler_path.is_file():
|
|
logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without scaler.")
|
|
else:
|
|
try:
|
|
target_scaler = torch.load(target_scaler_path)
|
|
# Basic check if it looks like a scaler
|
|
if not isinstance(target_scaler, (BaseEstimator, TransformerMixin)):
|
|
logger.warning(f"Loaded object from {target_scaler_path} might not be a valid scaler ({type(target_scaler)}).")
|
|
# Decide if this should be a hard failure or just a warning
|
|
else:
|
|
logger.debug(f"Loaded target scaler from {target_scaler_path}")
|
|
except Exception as e:
|
|
logger.error(f"Error loading target scaler from {target_scaler_path}: {e}", exc_info=True)
|
|
# Decide if this should be a hard failure
|
|
return None # Fail hard if scaler loading fails
|
|
loaded_artifacts['target_scaler'] = target_scaler
|
|
|
|
        # 4. Initialize the model architecture.
        # The model and feature configs should agree on the forecast horizon
        # (MainConfig validation is expected to guarantee this).
        if set(main_config.model.forecast_horizon) != set(main_config.features.forecast_horizon):
            logger.warning(
                f"Mismatch between model ({main_config.model.forecast_horizon}) and "
                f"feature ({main_config.features.forecast_horizon}) forecast horizons "
                f"in config {config_path}. Using the feature config horizon."
            )
            # This may indicate a problem with the saved config. We proceed with
            # the feature config horizon rather than mutating the model config,
            # which could silently change the architecture.

        model_instance = LSTMForecastLightningModule(
            model_config=main_config.model,
            train_config=main_config.training,  # The module needs the train config at init time.
            input_size=input_size,
            target_scaler=target_scaler,  # Used internally during inference, if provided.
        )
        logger.debug("Initialized model architecture.")
        # 5. Load the model state dict.
        if not model_path.is_file():
            logger.error(f"Model checkpoint file not found at {model_path}")
            return None
        # Load onto CPU first to avoid GPU memory issues when the loading
        # machine differs from the training machine.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # Lightning checkpoints nest the weights under a 'state_dict' key;
        # unwrap it if present.
        state_dict = checkpoint.get('state_dict', checkpoint)
        # Strip a leading 'model.' prefix, commonly added by the Lightning
        # wrapper when the checkpoint was saved. Keys without the prefix are
        # kept unchanged.
        if any(key.startswith('model.') for key in state_dict):
            state_dict = {
                (k[len('model.'):] if k.startswith('model.') else k): v
                for k, v in state_dict.items()
            }
            logger.debug("Adjusted state dict keys (removed 'model.' prefix).")

        # Prefer strict=True (safer); fall back to strict=False only if the
        # strict load fails on mismatched keys.
        try:
            load_result = model_instance.load_state_dict(state_dict, strict=True)
            logger.debug(f"Model state loaded. Result: {load_result}")
        except RuntimeError as e:
            logger.error(f"Error loading state dict into model (strict=True): {e}. Trying strict=False.")
            try:
                load_result = model_instance.load_state_dict(state_dict, strict=False)
                logger.warning(f"Model state loaded with strict=False. Result: {load_result}. Check for missing/unexpected keys.")
            except Exception as e_false:
                logger.error(f"Failed to load state dict even with strict=False: {e_false}", exc_info=True)
                return None
        model_instance.eval()  # Set the model to evaluation mode for inference.
        loaded_artifacts['model_instance'] = model_instance
        logger.info(f"Successfully loaded single model artifact: {model_path.name}")

        return loaded_artifacts
    except FileNotFoundError:
        logger.error(f"A required file was not found while loading artifacts for {model_path.parent}.", exc_info=True)
        return None
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML config file {config_path}: {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Failed to load single model artifact from {model_path.parent}: {e}", exc_info=True)
        return None
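
# Example usage (a minimal sketch; the directory layout and file names below
# are assumptions about the training pipeline's output, not a guaranteed
# contract):
#
#     artifact_dir = Path("output/hpo_study/fold_01")
#     artifacts = load_single_model_artifact(
#         model_path=artifact_dir / "best_model.ckpt",
#         config_path=artifact_dir / "config.yaml",
#         input_size_path=artifact_dir / "input_size.pt",
#         target_scaler_path=artifact_dir / "target_scaler.pt",
#     )
#     if artifacts is not None:
#         model = artifacts['model_instance']
#         with torch.no_grad():
#             preds = model(batch)  # `batch` is a hypothetical input tensor
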

def load_ensemble_artifact(
    ensemble_definition_path: Path,
    hpo_base_output_dir: Path,  # Base directory of the HPO study results (including the ensemble JSON).
) -> Optional[Dict[str, Any]]:
    """
    Load the artifacts for an ensemble based on its definition JSON file.

    Args:
        ensemble_definition_path: Path to the _best_ensemble.json file.
        hpo_base_output_dir: The base directory where the HPO study ran and
            where relative paths within the JSON are anchored.

    Returns:
        A dictionary containing 'ensemble_method', 'fold_artifacts' (a list of
        dictionaries, each shaped like the output of load_single_model_artifact),
        'ensemble_feature_config', and 'ensemble_target_col', or None if
        loading fails.
    """
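    # Expected shape of the definition file -- a sketch inferred from the
    # fields read below; the concrete values depend on the HPO/ensemble step
    # that produced the file:
    #
    #     {
    #         "ensemble_method": "mean",              <- example value
    #         "ensemble_artifacts_base_dir": "study_xyz/ensemble_artifacts",
    #         "fold_models": [
    #             {
    #                 "fold_id": 1,
    #                 "model_path": "fold_01/best_model.ckpt",
    #                 "config_path": "fold_01/config.yaml",
    #                 "input_size_path": "fold_01/input_size.pt",
    #                 "target_scaler_path": "fold_01/target_scaler.pt"
    #             }
    #         ]
    #     }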
logger.info(f"Loading ensemble artifact definition from: {ensemble_definition_path}")
|
|
|
|
try:
|
|
if not ensemble_definition_path.is_file():
|
|
logger.error(f"Ensemble definition file not found at: {ensemble_definition_path}")
|
|
return None
|
|
with open(ensemble_definition_path, 'r') as f:
|
|
ensemble_definition = json.load(f)
|
|
except json.JSONDecodeError as e:
|
|
logger.error(f"Error decoding ensemble definition JSON file: {e}", exc_info=True)
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"Error loading ensemble definition: {e}", exc_info=True)
|
|
return None
|
|
|
|
    # Extract the required fields from the definition.
    ensemble_method = ensemble_definition.get("ensemble_method")
    fold_models_definitions = ensemble_definition.get("fold_models")
    # Base directory for fold artifacts, *relative to* hpo_base_output_dir.
    relative_artifacts_base_dir = ensemble_definition.get("ensemble_artifacts_base_dir")

    if not ensemble_method or not fold_models_definitions:
        logger.error("Ensemble definition file is missing 'ensemble_method' or the 'fold_models' list.")
        return None
    if not relative_artifacts_base_dir:
        logger.error("Ensemble definition file is missing 'ensemble_artifacts_base_dir'. Cannot locate fold artifacts.")
        return None
    # --- Determine the absolute path to the fold artifacts ---
    # Paths inside fold_models are relative to ensemble_artifacts_base_dir,
    # which is itself relative to hpo_base_output_dir.
    absolute_artifacts_base_dir = hpo_base_output_dir / Path(relative_artifacts_base_dir)
    logger.debug(f"Absolute base directory for fold artifacts: {absolute_artifacts_base_dir}")
    if not absolute_artifacts_base_dir.is_dir():
        logger.error(f"Calculated absolute artifact base directory does not exist or is not a directory: {absolute_artifacts_base_dir}")
        return None
    loaded_fold_artifacts: List[Dict[str, Any]] = []
    common_feature_config: Optional[FeatureConfig] = None
    common_target_col: Optional[str] = None

    logger.info(f"Loading artifacts for {len(fold_models_definitions)} folds defined in the ensemble...")
    # --- Load artifacts for each fold ---
    for i, fold_def in enumerate(fold_models_definitions):
        fold_id = fold_def.get("fold_id", i + 1)
        logger.debug(f"--- Loading Fold {fold_id} ---")

        model_path_rel = fold_def.get("model_path")
        scaler_path_rel = fold_def.get("target_scaler_path")
        input_size_path_rel = fold_def.get("input_size_path")
        config_path_rel = fold_def.get("config_path")

        if not model_path_rel or not input_size_path_rel or not config_path_rel:
            logger.error(f"Fold {fold_id}: Definition is missing required path(s) (model, input_size, or config). Skipping fold.")
            continue

        # Construct absolute paths for this fold's artifacts.
        try:
            abs_model_path = (absolute_artifacts_base_dir / Path(model_path_rel)).resolve()
            abs_input_size_path = (absolute_artifacts_base_dir / Path(input_size_path_rel)).resolve()
            abs_config_path = (absolute_artifacts_base_dir / Path(config_path_rel)).resolve()
            abs_scaler_path = (absolute_artifacts_base_dir / Path(scaler_path_rel)).resolve() if scaler_path_rel else None

            logger.debug(f"Fold {fold_id} - Model Path: {abs_model_path}")
            logger.debug(f"Fold {fold_id} - Config Path: {abs_config_path}")
            logger.debug(f"Fold {fold_id} - Input Size Path: {abs_input_size_path}")
            logger.debug(f"Fold {fold_id} - Scaler Path: {abs_scaler_path}")

            # Load this fold's artifacts via load_single_model_artifact.
            single_fold_loaded_artifacts = load_single_model_artifact(
                model_path=abs_model_path,
                config_path=abs_config_path,
                input_size_path=abs_input_size_path,
                target_scaler_path=abs_scaler_path,
            )
            if single_fold_loaded_artifacts:
                # Record the fold id for later reference.
                single_fold_loaded_artifacts['fold_id'] = fold_id
                loaded_fold_artifacts.append(single_fold_loaded_artifacts)
                logger.info(f"Successfully loaded artifacts for fold {fold_id}.")

                # --- Consistency check across folds ---
                # Remember the feature config and target column of the first
                # successful fold, then compare every subsequent fold against them.
                current_feature_config = single_fold_loaded_artifacts['feature_config']
                current_target_col = single_fold_loaded_artifacts['main_forecasting_config'].data.target_col

                if common_feature_config is None:
                    common_feature_config = current_feature_config
                    common_target_col = current_target_col
                    logger.debug(f"Set common feature config and target column based on fold {fold_id}.")
                else:
                    # Compare the crucial feature engineering settings; extend with
                    # further fields (e.g. lags, rolling windows) if needed.
                    if common_feature_config.sequence_length != current_feature_config.sequence_length or \
                       set(common_feature_config.forecast_horizon) != set(current_feature_config.forecast_horizon) or \
                       common_feature_config.scaling_method != current_feature_config.scaling_method:
                        logger.error(f"Fold {fold_id}: Feature configuration mismatch with previous folds. Cannot proceed with this ensemble definition.")
                        return None  # Fail hard: inconsistent configs make the folds incompatible.
                    if common_target_col != current_target_col:
                        logger.error(f"Fold {fold_id}: Target column '{current_target_col}' differs from previous folds ('{common_target_col}'). Cannot proceed.")
                        return None  # Fail hard for the same reason.
            else:
                logger.error(f"Failed to load artifacts for fold {fold_id}. Skipping fold.")
                # A single failed fold does not abort the ensemble here; we
                # check below whether enough folds loaded successfully.

        except TypeError as e:
            # Raised if a path in the definition is None or of an invalid type.
            logger.error(f"Fold {fold_id}: Error constructing artifact paths - check definition file content: {e}", exc_info=True)
            continue
        except Exception as e:
            logger.error(f"Fold {fold_id}: Unexpected error during loading: {e}", exc_info=True)
            continue  # Skip this fold.
    # --- Final checks and return ---
    if not loaded_fold_artifacts:
        logger.error("Failed to load artifacts for *any* fold in the ensemble.")
        return None

    # Warn if some of the defined folds failed to load; the ensemble proceeds
    # with the subset that did.
    if len(loaded_fold_artifacts) < len(fold_models_definitions):
        logger.warning(f"Only {len(loaded_fold_artifacts)} of {len(fold_models_definitions)} defined folds loaded successfully.")

    if common_feature_config is None or common_target_col is None:
        # Cannot happen if loaded_fold_artifacts is non-empty, but guard anyway.
        logger.error("Internal error: Could not determine common feature config or target column for the ensemble.")
        return None

    logger.info(f"Successfully loaded artifacts for {len(loaded_fold_artifacts)} ensemble folds.")
    return {
        'ensemble_method': ensemble_method,
        'fold_artifacts': loaded_fold_artifacts,            # List of per-fold artifact dicts.
        'ensemble_feature_config': common_feature_config,   # The common feature config.
        'ensemble_target_col': common_target_col,           # The common target column name.
    }
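

if __name__ == "__main__":
    # Minimal usage sketch for the ensemble loader. The paths and file name
    # below are assumptions about the HPO output layout, not part of this
    # module's API; the definition file is typically named
    # <study>_best_ensemble.json.
    logging.basicConfig(level=logging.INFO)

    hpo_dir = Path("output/hpo_study")
    ensemble = load_ensemble_artifact(
        ensemble_definition_path=hpo_dir / "example_best_ensemble.json",
        hpo_base_output_dir=hpo_dir,
    )
    if ensemble is not None:
        print(f"Ensemble method: {ensemble['ensemble_method']}")
        print(f"Loaded folds:    {len(ensemble['fold_artifacts'])}")
        print(f"Target column:   {ensemble['ensemble_target_col']}")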