intermediate backup
@@ -5,7 +5,7 @@ This package contains configuration models, helper functions, and other utilities
 """

 # Expose configuration models
-from .config_model import (
+from .forecast_config_model import (
     MainConfig,
     DataConfig,
     FeatureConfig,
@@ -44,7 +44,7 @@ class DataConfig(BaseModel):
 class FeatureConfig(BaseModel):
     """Configuration for feature engineering and preprocessing."""
     sequence_length: int = Field(..., gt=0)
-    forecast_horizon: int = Field(..., gt=0)
+    forecast_horizon: List[int] = Field(..., min_length=1, description="List of specific forecast horizons to predict (e.g., [1, 6, 12]).")
     lags: List[int] = []
     rolling_window_sizes: List[int] = []
     use_time_features: bool = True
@@ -55,11 +55,11 @@ class FeatureConfig(BaseModel):
     clipping: ClippingConfig = ClippingConfig() # Default instance
     scaling_method: Optional[Literal['standard', 'minmax']] = 'standard' # Added literal validation

-    @field_validator('lags', 'rolling_window_sizes')
+    @field_validator('lags', 'rolling_window_sizes', 'forecast_horizon')
     @classmethod
     def check_positive_list_values(cls, v: List[int]) -> List[int]:
         if any(val <= 0 for val in v):
-            raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
+            raise ValueError('Lists lags, rolling_window_sizes, and forecast_horizon must contain only positive values')
         return v

 class ModelConfig(BaseModel):
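For reference, a minimal standalone sketch of the validation pattern above (assumes Pydantic v2; the class name and the trimmed field set are hypothetical):

    # Minimal sketch; mirrors the field_validator usage in the diff above.
    from typing import List
    from pydantic import BaseModel, Field, field_validator

    class FeatureConfigSketch(BaseModel):
        forecast_horizon: List[int] = Field(..., min_length=1)
        lags: List[int] = []
        rolling_window_sizes: List[int] = []

        @field_validator('lags', 'rolling_window_sizes', 'forecast_horizon')
        @classmethod
        def check_positive_list_values(cls, v: List[int]) -> List[int]:
            if any(val <= 0 for val in v):
                raise ValueError('must contain only positive values')
            return v

    FeatureConfigSketch(forecast_horizon=[1, 6, 12])   # OK
    # FeatureConfigSketch(forecast_horizon=[])          # rejected: min_length=1
    # FeatureConfigSketch(forecast_horizon=[0, 6])      # rejected: non-positive value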
@@ -69,8 +69,8 @@ class ModelConfig(BaseModel):
     num_layers: int = Field(..., gt=0)
     dropout: float = Field(..., ge=0.0, le=1.0)
     use_residual_skips: bool = False
-    # Add forecast_horizon here to ensure LightningModule gets it directly
-    forecast_horizon: Optional[int] = Field(None, gt=0) # Will be set from FeatureConfig
+    # forecast_horizon: Optional[int] = Field(None, gt=0) # OLD
+    forecast_horizon: Optional[List[int]] = Field(None, min_length=1) # Will be set from FeatureConfig

 class TrainingConfig(BaseModel):
     """Configuration for the training process (PyTorch Lightning)."""
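One way the list-valued horizon might be consumed by the model (purely illustrative; the layer and variable names are not taken from this repository) is to size the output head by the number of requested horizons:

    # Hypothetical sketch: one output per requested forecast horizon.
    import torch.nn as nn

    forecast_horizon = [1, 6, 12]   # as carried by ModelConfig after validation
    hidden_size = 64
    output_head = nn.Linear(hidden_size, len(forecast_horizon))  # 3 outputs here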
@@ -103,26 +103,35 @@ class EvaluationConfig(BaseModel):
 class OptunaConfig(BaseModel):
     """Optional configuration for Optuna hyperparameter optimization."""
     enabled: bool = False
     study_name: str = "default_study" # Added study_name
     n_trials: int = Field(20, gt=0)
     storage: Optional[str] = None # e.g., "sqlite:///output/hpo_results/study.db"
     direction: Literal['minimize', 'maximize'] = 'minimize'
-    metric_to_optimize: str = 'val_mae_orig_scale'
-    pruning: bool = True
+    metric_to_optimize: str = 'val_MeanAbsoluteError_Original_Scale' # Updated default metric
+    pruning: bool = True

 # --- Top-Level Configuration Model ---

 class MainConfig(BaseModel):
     """Main configuration model nesting all sections."""
     project_name: str = "TimeSeriesForecasting"
-    random_seed: Optional[int] = 42 # Added top-level seed
+    random_seed: Optional[int] = 42
     log_level: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
+    output_dir: str = Field("output/cv_results", description="Base directory for saving all outputs (results, logs, models, plots).")

+    # --- Execution Control ---
+    run_cross_validation: bool = Field(True, description="Run the main cross-validation training loop?")
+    run_classic_training: bool = Field(True, description="Run a single classic train/val/test split training?")
+    run_ensemble_evaluation: bool = Field(True, description="Run ensemble evaluation using CV fold models?")
+    # --- End Execution Control ---

     data: DataConfig
     features: FeatureConfig
-    model: ModelConfig # ModelConfig no longer contains input_size
+    model: ModelConfig
     training: TrainingConfig
     cross_validation: CrossValidationConfig
     evaluation: EvaluationConfig
-    optuna: Optional[OptunaConfig] = OptunaConfig() # Added optional Optuna config
+    optuna: Optional[OptunaConfig] = OptunaConfig()

     @model_validator(mode='after')
     def check_forecast_horizon_consistency(self) -> 'MainConfig':
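A hypothetical sketch of how a runner script could consume the new execution-control flags (the three pipeline functions are placeholders, not part of this commit; load_config is the helper added below):

    # Hypothetical wiring sketch; placeholder pipeline functions only.
    from pathlib import Path

    def run_cv_pipeline(cfg): ...        # placeholder for the CV training loop
    def run_classic_pipeline(cfg): ...   # placeholder for the single-split run
    def evaluate_ensemble(cfg): ...      # placeholder for ensemble evaluation

    config = load_config(Path("config.yaml"))
    if config.run_cross_validation:
        run_cv_pipeline(config)
    if config.run_classic_training:
        run_classic_pipeline(config)
    if config.run_ensemble_evaluation:
        evaluate_ensemble(config)        # requires the CV fold models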
@@ -131,20 +140,33 @@ class MainConfig(BaseModel):
         if self.model.forecast_horizon is None:
             # If model config doesn't have it, set it from features config
             self.model.forecast_horizon = self.features.forecast_horizon
-        elif self.model.forecast_horizon != self.features.forecast_horizon:
+        elif set(self.model.forecast_horizon) != set(self.features.forecast_horizon): # Compare sets for content equality
             # If both are set but differ, raise error
             raise ValueError(
                 f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
                 f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
             )
-        # After potential setting, ensure model.forecast_horizon is actually set
-        if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
-            raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
+        # After potential setting, ensure model.forecast_horizon is actually set and valid
+        if self.model and (
+            self.model.forecast_horizon is None or
+            not isinstance(self.model.forecast_horizon, list) or # Check type
+            len(self.model.forecast_horizon) == 0 or # Check not empty
+            any(h <= 0 for h in self.model.forecast_horizon) # Check positive values
+        ):
+            raise ValueError("ModelConfig requires a non-empty list of positive forecast_horizon values (must be set in features config if not set explicitly in model config).")

         # Input size check is removed as it's not part of static config anymore

         return self

+    @model_validator(mode='after')
+    def check_execution_flags(self) -> 'MainConfig':
+        if not self.run_cross_validation and not self.run_classic_training:
+            raise ValueError("At least one of 'run_cross_validation' or 'run_classic_training' must be True.")
+        if self.run_ensemble_evaluation and not self.run_cross_validation:
+            raise ValueError("'run_ensemble_evaluation' requires 'run_cross_validation' to be True (needs CV fold models).")
+        return self

     class Config:
         # Example configuration for Pydantic itself
         validate_assignment = True # Re-validate on assignment
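Note that the set-based comparison in check_forecast_horizon_consistency ignores ordering and duplicates; a quick illustration:

    # Order (and duplicates) do not matter for the consistency check:
    set([1, 6, 12]) == set([12, 6, 1])   # True  -> passes
    set([1, 6, 12]) == set([1, 6])       # False -> raises ValueError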
forecasting_model/utils/helper.py (new file, 173 lines added)
@@ -0,0 +1,173 @@
import argparse
import json
import logging
import random
from pathlib import Path
from typing import Optional, List, Dict

import numpy as np
import pandas as pd
import torch

import yaml

from forecasting_model import MainConfig

# Get the root logger
logger = logging.getLogger(__name__)

def parse_arguments():
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run the Time Series Forecasting training pipeline using a configuration file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='config.yaml',
        help="Path to the YAML configuration file."
    )
    # Removed seed, debug, and output-dir arguments

    args = parser.parse_args()
    return args


def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file using Pydantic.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Validated MainConfig object.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If the config doesn't match the schema.
    """
    if not config_path.is_file():
        logger.error(f"Configuration file not found at: {config_path}")
        raise FileNotFoundError(f"Config file not found: {config_path}")

    logger.info(f"Loading configuration from: {config_path}")
    try:
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)

        # Validate configuration using Pydantic model
        config = MainConfig(**config_dict)
        logger.info("Configuration loaded and validated successfully.")
        return config
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
        raise
    except Exception as e: # Catches Pydantic validation errors too
        logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
        raise
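A plausible entry-point sketch using the helpers above (the calling script itself is not part of this diff; it assumes the imports at the top of helper.py):

    # Sketch of a main script tying the helpers together.
    if __name__ == "__main__":
        args = parse_arguments()
        config = load_config(Path(args.config))  # raises (and logs) if missing or invalid
        set_seeds(config.random_seed)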


def set_seeds(seed: Optional[int] = 42) -> None:
    """
    Set random seeds for reproducibility across libraries.

    Args:
        seed: The seed value to use. If None, uses default 42.
    """
    actual_seed = seed if seed is not None else 42
    if seed is None:
        logger.warning(f"No random_seed specified in config, using default seed: {actual_seed}")
    else:
        logger.info(f"Setting random seed from config: {actual_seed}")

    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    # Ensure reproducibility for CUDA operations where possible
    if torch.cuda.is_available():
        torch.cuda.manual_seed(actual_seed)
        torch.cuda.manual_seed_all(actual_seed) # For multi-GPU
    # These settings can slow down training but improve reproducibility
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # PyTorch Lightning seeding (optional, as we seed torch directly)
    # pl.seed_everything(seed, workers=True) # workers=True ensures dataloader reproducibility
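A quick illustration of what set_seeds guarantees: re-seeding reproduces the same random draws.

    # Re-seeding yields identical draws (illustration only).
    set_seeds(123)
    a = torch.rand(3)
    set_seeds(123)
    b = torch.rand(3)
    assert torch.equal(a, b)  # identical tensors after re-seeding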


def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.
    Handles potential NaN values by ignoring them.

    Args:
        all_fold_metrics: A list where each element is a dictionary of
                          metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).

    Returns:
        A dictionary where keys are metric names and values are dicts
        containing 'mean' and 'std' for that metric across folds.
        Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
    """
    if not all_fold_metrics:
        logger.warning("Received empty list for metric aggregation.")
        return {}

    aggregated: Dict[str, Dict[str, float]] = {}
    # Get metric names from the first valid fold's results
    first_valid_metrics = next((m for m in all_fold_metrics if m), None)
    if not first_valid_metrics:
        logger.warning("No valid fold metrics found for aggregation.")
        return {}
    metric_names = list(first_valid_metrics.keys())

    for metric in metric_names:
        # Collect values for this metric across all folds, ignoring NaNs
        values = [fold_metrics.get(metric) for fold_metrics in all_fold_metrics if fold_metrics and metric in fold_metrics]
        valid_values = [v for v in values if v is not None and not np.isnan(v)]

        if not valid_values:
            logger.warning(f"No valid values found for metric '{metric}' across folds.")
            mean_val = np.nan
            std_val = np.nan
        else:
            mean_val = float(np.mean(valid_values))
            std_val = float(np.std(valid_values))
            logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")

        aggregated[metric] = {'mean': mean_val, 'std': std_val}

    return aggregated
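A toy example of the aggregation (values invented for illustration):

    # NaNs are dropped per metric before computing mean/std.
    fold_metrics = [
        {'MAE': 1.0, 'RMSE': 2.0},
        {'MAE': 3.0, 'RMSE': np.nan},   # NaN ignored for RMSE
    ]
    agg = aggregate_cv_metrics(fold_metrics)
    # agg == {'MAE': {'mean': 2.0, 'std': 1.0}, 'RMSE': {'mean': 2.0, 'std': 0.0}}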


def save_results(results: Dict, filename: Path):
    """Save dictionary results to a JSON file."""
    try:
        filename.parent.mkdir(parents=True, exist_ok=True)
        # Convert numpy types to native Python types for JSON serialization
        results_serializable = json.loads(json.dumps(results, cls=NumpyEncoder))
        with open(filename, 'w') as f:
            json.dump(results_serializable, f, indent=4)
        logger.info(f"Saved results to {filename}")
    except TypeError as e:
        logger.error(f"Serialization error saving results to {filename}. Check for non-serializable types (e.g., numpy types): {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.bool_, bool)):
            return bool(obj)
        elif pd.isna(obj): # Handle pandas NaT or numpy NaN gracefully
            return None
        return super(NumpyEncoder, self).default(obj)
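A small round-trip sketch for save_results with the encoder (path and values are illustrative):

    # Numpy scalars and arrays are converted to plain JSON types.
    results = {'MAE': np.float64(1.23), 'fold_sizes': np.array([100, 100, 98])}
    save_results(results, Path("output/cv_results/example_metrics.json"))
    # Written JSON: {"MAE": 1.23, "fold_sizes": [100, 100, 98]}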