import logging import yaml import json from pathlib import Path from typing import Dict, Any, Optional, List import torch from sklearn.base import BaseEstimator, TransformerMixin # For scaler type hint # Import necessary components from forecasting_model from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig from forecasting_model.train.model import LSTMForecastLightningModule logger = logging.getLogger(__name__) def load_single_model_artifact( model_path: Path, config_path: Path, input_size_path: Path, target_scaler_path: Optional[Path] = None ) -> Optional[Dict[str, Any]]: """ Loads artifacts for a single trained model checkpoint. Args: model_path: Path to the model checkpoint file (.ckpt). config_path: Path to the corresponding main YAML config file. input_size_path: Path to the input_size.pt file. target_scaler_path: Optional path to the target_scaler.pt file. Returns: A dictionary containing loaded artifacts ('model_instance', 'feature_config', 'target_scaler', 'main_forecasting_config'), or None if loading fails. """ logger.info(f"Loading single model artifact from directory: {model_path.parent}") loaded_artifacts = {} try: # 1. Load Config if not config_path.is_file(): logger.error(f"Config file not found at {config_path}") return None with open(config_path, 'r') as f: config_data = yaml.safe_load(f) main_config = MainConfig(**config_data) loaded_artifacts['main_forecasting_config'] = main_config loaded_artifacts['feature_config'] = main_config.features logger.debug(f"Loaded config from {config_path}") # 2. Load Input Size if not input_size_path.is_file(): logger.error(f"Input size file not found at {input_size_path}") return None input_size = torch.load(input_size_path) if not isinstance(input_size, int) or input_size <= 0: logger.error(f"Invalid input size loaded from {input_size_path}: {input_size}") return None logger.debug(f"Loaded input size ({input_size}) from {input_size_path}") # 3. Load Target Scaler (Optional) target_scaler = None if target_scaler_path: if not target_scaler_path.is_file(): logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without scaler.") else: try: target_scaler = torch.load(target_scaler_path) # Basic check if it looks like a scaler if not isinstance(target_scaler, (BaseEstimator, TransformerMixin)): logger.warning(f"Loaded object from {target_scaler_path} might not be a valid scaler ({type(target_scaler)}).") # Decide if this should be a hard failure or just a warning else: logger.debug(f"Loaded target scaler from {target_scaler_path}") except Exception as e: logger.error(f"Error loading target scaler from {target_scaler_path}: {e}", exc_info=True) # Decide if this should be a hard failure return None # Fail hard if scaler loading fails loaded_artifacts['target_scaler'] = target_scaler # 4. Initialize Model Architecture # Ensure model config forecast horizon matches feature config (should be guaranteed by MainConfig validation) if set(main_config.model.forecast_horizon) != set(main_config.features.forecast_horizon): logger.warning(f"Mismatch between model ({main_config.model.forecast_horizon}) and feature ({main_config.features.forecast_horizon}) forecast horizons in config {config_path}. Using feature config.") # This might indicate an issue with the saved config, but we proceed using the feature config horizon # main_config.model.forecast_horizon = main_config.features.forecast_horizon # Correct it for model init? Risky. model_instance = LSTMForecastLightningModule( model_config=main_config.model, train_config=main_config.training, # Pass train config if needed input_size=input_size, target_scaler=target_scaler # Pass scaler to model if it uses it internally during inference ) logger.debug("Initialized model architecture.") # 5. Load Model State Dictionary if not model_path.is_file(): logger.error(f"Model checkpoint file not found at {model_path}") return None # Load onto CPU first to avoid GPU memory issues if the loading machine is different state_dict = torch.load(model_path, map_location=torch.device('cpu')) # Adjust state dict keys if saved with 'model.' prefix from Lightning wrapper common during saving ckpt if any(key.startswith('model.') for key in state_dict.get('state_dict', state_dict).keys()): state_dict = {k.partition('model.')[2]: v for k, v in state_dict.get('state_dict', state_dict).items()} logger.debug("Adjusted state dict keys (removed 'model.' prefix).") # Load the state dict # Use strict=False initially if unsure about exact key matching, but strict=True is safer try: load_result = model_instance.load_state_dict(state_dict, strict=True) logger.debug(f"Model state loaded. Result: {load_result}") except RuntimeError as e: logger.error(f"Error loading state dict into model (strict=True): {e}. Trying strict=False.") try: load_result = model_instance.load_state_dict(state_dict, strict=False) logger.warning(f"Model state loaded with strict=False. Result: {load_result}. Check for missing/unexpected keys.") except Exception as e_false: logger.error(f"Failed to load state dict even with strict=False: {e_false}", exc_info=True) return None model_instance.eval() # Set model to evaluation mode loaded_artifacts['model_instance'] = model_instance logger.info(f"Successfully loaded single model artifact: {model_path.name}") return loaded_artifacts except FileNotFoundError: logger.error(f"A required file was not found during artifact loading for {model_path.parent}.", exc_info=True) return None except yaml.YAMLError as e: logger.error(f"Error parsing YAML config file {config_path}: {e}", exc_info=True) return None except Exception as e: logger.error(f"Failed to load single model artifact from {model_path.parent}: {e}", exc_info=True) return None def load_ensemble_artifact( ensemble_definition_path: Path, hpo_base_output_dir: Path # Base directory where HPO study results (including ensemble JSON) are saved ) -> Optional[Dict[str, Any]]: """ Loads artifacts for an ensemble based on its definition JSON file. Args: ensemble_definition_path: Path to the _best_ensemble.json file. hpo_base_output_dir: The base directory where the HPO study ran and where relative paths within the JSON are anchored. Returns: A dictionary containing 'ensemble_method', 'fold_artifacts' (a list of dictionaries, each like the output of load_single_model_artifact), 'ensemble_feature_config', and 'ensemble_target_col', or None if loading fails. """ logger.info(f"Loading ensemble artifact definition from: {ensemble_definition_path}") try: if not ensemble_definition_path.is_file(): logger.error(f"Ensemble definition file not found at: {ensemble_definition_path}") return None with open(ensemble_definition_path, 'r') as f: ensemble_definition = json.load(f) except json.JSONDecodeError as e: logger.error(f"Error decoding ensemble definition JSON file: {e}", exc_info=True) return None except Exception as e: logger.error(f"Error loading ensemble definition: {e}", exc_info=True) return None # Extract information from the definition ensemble_method = ensemble_definition.get("ensemble_method") fold_models_definitions = ensemble_definition.get("fold_models") # Base directory for artifacts *relative to* hpo_base_output_dir relative_artifacts_base_dir = ensemble_definition.get("ensemble_artifacts_base_dir") if not ensemble_method or not fold_models_definitions: logger.error("Ensemble definition file is missing 'ensemble_method' or 'fold_models' list.") return None if not relative_artifacts_base_dir: logger.error("Ensemble definition file is missing 'ensemble_artifacts_base_dir'. Cannot locate fold artifacts.") return None # --- Determine Absolute Path to Fold Artifacts --- # The paths inside fold_models are relative to ensemble_artifacts_base_dir, # which itself is relative to hpo_base_output_dir. absolute_artifacts_base_dir = hpo_base_output_dir / Path(relative_artifacts_base_dir) logger.debug(f"Absolute base directory for fold artifacts: {absolute_artifacts_base_dir}") if not absolute_artifacts_base_dir.is_dir(): logger.error(f"Calculated absolute artifact base directory does not exist or is not a directory: {absolute_artifacts_base_dir}") return None loaded_fold_artifacts: List[Dict[str, Any]] = [] common_feature_config: Optional[FeatureConfig] = None common_target_col: Optional[str] = None logger.info(f"Loading artifacts for {len(fold_models_definitions)} folds defined in the ensemble...") # --- Load Artifacts for Each Fold --- for i, fold_def in enumerate(fold_models_definitions): fold_id = fold_def.get("fold_id", i + 1) logger.debug(f"--- Loading Fold {fold_id} ---") model_path_rel = fold_def.get("model_path") scaler_path_rel = fold_def.get("target_scaler_path") input_size_path_rel = fold_def.get("input_size_path") config_path_rel = fold_def.get("config_path") if not model_path_rel or not input_size_path_rel or not config_path_rel: logger.error(f"Fold {fold_id}: Definition is missing required path(s) (model, input_size, or config). Skipping fold.") continue # Construct absolute paths for this fold's artifacts try: abs_model_path = (absolute_artifacts_base_dir / Path(model_path_rel)).resolve() abs_input_size_path = (absolute_artifacts_base_dir / Path(input_size_path_rel)).resolve() abs_config_path = (absolute_artifacts_base_dir / Path(config_path_rel)).resolve() abs_scaler_path = (absolute_artifacts_base_dir / Path(scaler_path_rel)).resolve() if scaler_path_rel else None logger.debug(f"Fold {fold_id} - Model Path: {abs_model_path}") logger.debug(f"Fold {fold_id} - Config Path: {abs_config_path}") logger.debug(f"Fold {fold_id} - Input Size Path: {abs_input_size_path}") logger.debug(f"Fold {fold_id} - Scaler Path: {abs_scaler_path}") # Load the artifacts for this single fold using the other function single_fold_loaded_artifacts = load_single_model_artifact( model_path=abs_model_path, config_path=abs_config_path, input_size_path=abs_input_size_path, target_scaler_path=abs_scaler_path ) if single_fold_loaded_artifacts: # Add fold_id for reference single_fold_loaded_artifacts['fold_id'] = fold_id loaded_fold_artifacts.append(single_fold_loaded_artifacts) logger.info(f"Successfully loaded artifacts for fold {fold_id}.") # --- Consistency Check (Optional but Recommended) --- # Store the feature config and target col from the first successful fold # Then compare subsequent folds against these current_feature_config = single_fold_loaded_artifacts['feature_config'] current_target_col = single_fold_loaded_artifacts['main_forecasting_config'].data.target_col if common_feature_config is None: common_feature_config = current_feature_config common_target_col = current_target_col logger.debug(f"Set common feature config and target column based on fold {fold_id}.") else: # Compare crucial feature engineering aspects if common_feature_config.sequence_length != current_feature_config.sequence_length or \ set(common_feature_config.forecast_horizon) != set(current_feature_config.forecast_horizon) or \ common_feature_config.scaling_method != current_feature_config.scaling_method: # Add more checks if needed logger.error(f"Fold {fold_id}: Feature configuration mismatch with previous folds. Cannot proceed with this ensemble definition.") # You might want to compare more fields like lags, rolling_windows etc. return None # Fail hard if configs are inconsistent if common_target_col != current_target_col: logger.error(f"Fold {fold_id}: Target column '{current_target_col}' differs from previous folds ('{common_target_col}'). Cannot proceed.") return None # Fail hard else: logger.error(f"Failed to load artifacts for fold {fold_id}. Skipping fold.") # Decide if ensemble loading should fail if *any* fold fails # For now, we continue and will check if enough folds loaded later except TypeError as e: # Catch potential errors if paths are None or invalid types logger.error(f"Fold {fold_id}: Error constructing artifact paths - check definition file content: {e}", exc_info=True) continue except Exception as e: logger.error(f"Fold {fold_id}: Unexpected error during loading: {e}", exc_info=True) continue # Skip this fold # --- Final Checks and Return --- if not loaded_fold_artifacts: logger.error("Failed to load artifacts for *any* fold in the ensemble.") return None # Add a check if a minimum number of folds is required (e.g., > 1) if len(loaded_fold_artifacts) < 1: # Or maybe check against len(fold_models_definitions)? logger.error(f"Only successfully loaded {len(loaded_fold_artifacts)} folds, which might be insufficient for the ensemble.") # Decide if this is an error or just a warning return None if common_feature_config is None or common_target_col is None: # This should not happen if loaded_fold_artifacts is not empty, but check anyway logger.error("Internal error: Could not determine common feature config or target column for the ensemble.") return None logger.info(f"Successfully loaded artifacts for {len(loaded_fold_artifacts)} ensemble folds.") return { 'ensemble_method': ensemble_method, 'fold_artifacts': loaded_fold_artifacts, # List of dicts 'ensemble_feature_config': common_feature_config, # The common config 'ensemble_target_col': common_target_col # The common target column name }