from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import Optional, List, Union, Literal


# --- Nested Configs ---

class WaveletTransformConfig(BaseModel):
    """Configuration for optional wavelet transform features."""
    apply: bool = False
    target_or_feature: Literal['target', 'feature'] = "target"
    wavelet_type: str = "db4"
    level: int = Field(3, gt=0)  # Decomposition level must be positive
    use_coeffs: List[str] = ["approx", "detail_1"]


class ClippingConfig(BaseModel):
    """Configuration for optional feature clipping."""
    apply: bool = False
    clip_min: float = -5.0
    clip_max: float = 5.0

    @model_validator(mode='after')
    def check_clip_range(self) -> 'ClippingConfig':
        if self.apply and self.clip_max <= self.clip_min:
            raise ValueError(
                f'clip_max ({self.clip_max}) must be greater than '
                f'clip_min ({self.clip_min}) when clipping is applied.'
            )
        return self
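
# Illustrative behaviour of check_clip_range above (a sketch assuming pydantic
# v2; kept in comments so nothing executes at import time). The cross-field
# check only fires when clipping is actually enabled:
#
#   from pydantic import ValidationError
#
#   ClippingConfig(apply=False, clip_min=5.0, clip_max=-5.0)  # accepted: apply is False
#   try:
#       ClippingConfig(apply=True, clip_min=5.0, clip_max=-5.0)
#   except ValidationError:
#       pass  # raised: clip_max must be greater than clip_min when clipping is applied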

# --- Main Config Sections ---

class DataConfig(BaseModel):
    """Configuration related to data loading and initial preparation."""
    data_path: str = Field(..., description="Path to the input CSV data file.")
    # --- Raw Data Specifics ---
    raw_datetime_col: str = Field(..., description="Name of the raw datetime column in the CSV (e.g., 'MTU (CET/CEST)').")
    raw_target_col: str = Field(..., description="Name of the raw target/price column in the CSV (e.g., 'Day-ahead Price [EUR/MWh]').")
    raw_datetime_format: str = '%d.%m.%Y %H:%M'  # Example format; make configurable if needed
    # --- Standardized Names & Processing ---
    datetime_col: str = Field(..., description="Standardized name for the datetime index after processing (e.g., 'Timestamp').")
    target_col: str = Field(..., description="Standardized name for the target column after processing (e.g., 'Price').")
    expected_frequency: Optional[str] = Field('h', description="Expected pandas frequency string (e.g., 'h', 'D', '15min'). If null, no frequency check/setting is performed.")
    fill_initial_target_nans: bool = Field(True, description="Forward/backward fill NaNs in the target column immediately after loading?")


class FeatureConfig(BaseModel):
    """Configuration for feature engineering and preprocessing."""
    sequence_length: int = Field(..., gt=0)
    forecast_horizon: int = Field(..., gt=0)
    lags: List[int] = []
    rolling_window_sizes: List[int] = []
    use_time_features: bool = True
    sinus_curve: bool = False   # Add a sine-of-time feature
    cosin_curve: bool = False   # Add a cosine-of-time feature
    wavelet_transform: Optional[WaveletTransformConfig] = None
    fill_nan: Optional[Union[str, float, int]] = 'ffill'  # e.g., 'ffill' or a constant such as 0
    clipping: ClippingConfig = ClippingConfig()  # Default instance (clipping disabled)
    scaling_method: Optional[Literal['standard', 'minmax']] = 'standard'

    @field_validator('lags', 'rolling_window_sizes')
    @classmethod
    def check_positive_list_values(cls, v: List[int]) -> List[int]:
        if any(val <= 0 for val in v):
            raise ValueError('lags and rolling_window_sizes must contain only positive values')
        return v


class ModelConfig(BaseModel):
    """Configuration for the forecasting model architecture."""
    # input_size is intentionally absent: it is determined dynamically from the feature set.
    hidden_size: int = Field(..., gt=0)
    num_layers: int = Field(..., gt=0)
    dropout: float = Field(..., ge=0.0, le=1.0)
    use_residual_skips: bool = False
    # forecast_horizon also lives here so the LightningModule receives it directly;
    # it is back-filled from FeatureConfig by MainConfig.check_forecast_horizon_consistency.
    forecast_horizon: Optional[int] = Field(None, gt=0)


class TrainingConfig(BaseModel):
    """Configuration for the training process (PyTorch Lightning)."""
    batch_size: int = Field(..., gt=0)
    epochs: int = Field(..., gt=0)  # Max epochs
    learning_rate: float = Field(..., gt=0.0)
    loss_function: Literal['MSE', 'MAE'] = 'MSE'
    # Device selection is handled by the PL Trainer's accelerator/devices arguments.
    early_stopping_patience: Optional[int] = Field(None, ge=1)  # Patience must be >= 1 if set
    scheduler_step_size: Optional[int] = Field(None, gt=0)
    scheduler_gamma: Optional[float] = Field(None, gt=0.0, lt=1.0)
    gradient_clip_val: Optional[float] = Field(None, ge=0.0)
    num_workers: int = Field(0, ge=0)
    precision: Literal[16, 32, 64, 'bf16'] = 32
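
# Illustrative YAML fragment for the training section above (a sketch: it
# assumes the config is loaded from a YAML file, field names mirror
# TrainingConfig, and the values are placeholders, not tuned recommendations;
# optional fields may be omitted or set to null):
#
#   training:
#     batch_size: 64
#     epochs: 100
#     learning_rate: 0.001
#     loss_function: MSE
#     early_stopping_patience: 10
#     scheduler_step_size: 25
#     scheduler_gamma: 0.5
#     gradient_clip_val: 1.0
#     num_workers: 4
#     precision: 32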

class CrossValidationConfig(BaseModel):
    """Configuration for time series cross-validation."""
    n_splits: int = Field(..., gt=0)
    test_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the test set.")
    val_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the validation set.")
    initial_train_size: Optional[Union[int, float]] = Field(None, gt=0.0, description="Size of the fixed training window (absolute number or fraction of total data > 0). If null, estimated automatically.")


class EvaluationConfig(BaseModel):
    """Configuration for the final evaluation process."""
    # Metrics (MAE, RMSE, ...) are defined internally by the evaluation code.
    eval_batch_size: int = Field(..., gt=0)
    save_plots: bool = True
    plot_sample_size: Optional[int] = Field(1000, gt=0)  # Max points per plot


class OptunaConfig(BaseModel):
    """Optional configuration for Optuna hyperparameter optimization."""
    enabled: bool = False
    n_trials: int = Field(20, gt=0)
    storage: Optional[str] = None  # e.g., "sqlite:///output/hpo_results/study.db"
    direction: Literal['minimize', 'maximize'] = 'minimize'
    metric_to_optimize: str = 'val_mae_orig_scale'
    pruning: bool = True


# --- Top-Level Configuration Model ---

class MainConfig(BaseModel):
    """Main configuration model nesting all sections."""
    # Pydantic settings: re-validate fields on assignment.
    # extra='forbid' could be added to reject fields not defined in the schema.
    model_config = ConfigDict(validate_assignment=True)

    project_name: str = "TimeSeriesForecasting"
    random_seed: Optional[int] = 42  # Top-level seed for reproducibility

    data: DataConfig
    features: FeatureConfig
    model: ModelConfig  # ModelConfig no longer contains input_size
    training: TrainingConfig
    cross_validation: CrossValidationConfig
    evaluation: EvaluationConfig
    optuna: Optional[OptunaConfig] = OptunaConfig()

    @model_validator(mode='after')
    def check_forecast_horizon_consistency(self) -> 'MainConfig':
        # Ensure the model config gets the forecast horizon from the features config if unset.
        if self.features and self.model:
            if self.model.forecast_horizon is None:
                self.model.forecast_horizon = self.features.forecast_horizon
            elif self.model.forecast_horizon != self.features.forecast_horizon:
                # Both are set but disagree: fail fast.
                raise ValueError(
                    f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
                    f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
                )
        # After the potential back-fill, model.forecast_horizon must be a positive integer.
        if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
            raise ValueError(
                "ModelConfig requires a positive forecast_horizon "
                "(set it in the features config if not set explicitly in the model config)."
            )
        return self
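
# A minimal usage sketch (assumptions: pydantic v2 is installed, and the nested
# values below are illustrative placeholders, not a real project config). It
# exercises the required fields and shows check_forecast_horizon_consistency
# back-filling model.forecast_horizon from the features section.
if __name__ == "__main__":
    cfg = MainConfig(
        data={
            "data_path": "data/prices.csv",  # hypothetical path
            "raw_datetime_col": "MTU (CET/CEST)",
            "raw_target_col": "Day-ahead Price [EUR/MWh]",
            "datetime_col": "Timestamp",
            "target_col": "Price",
        },
        features={"sequence_length": 72, "forecast_horizon": 24, "lags": [24, 48, 168]},
        model={"hidden_size": 64, "num_layers": 2, "dropout": 0.2},
        training={"batch_size": 64, "epochs": 50, "learning_rate": 1e-3},
        cross_validation={"n_splits": 5, "test_size_fraction": 0.1, "val_size_fraction": 0.1},
        evaluation={"eval_batch_size": 256},
    )
    assert cfg.model.forecast_horizon == 24  # copied from features by the validator
    print(cfg.model_dump_json(indent=2))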