intermediate backup
This commit is contained in:
@ -2,4 +2,32 @@
|
||||
Utility functions and classes for the forecasting model.
|
||||
|
||||
This package contains configuration models, helper functions, and other utilities.
|
||||
"""
|
||||
"""
|
||||
|
||||
# Expose configuration models
|
||||
from .config_model import (
|
||||
MainConfig,
|
||||
DataConfig,
|
||||
FeatureConfig,
|
||||
ModelConfig,
|
||||
TrainingConfig,
|
||||
CrossValidationConfig,
|
||||
EvaluationConfig,
|
||||
OptunaConfig,
|
||||
WaveletTransformConfig, # Expose nested configs if they might be used directly
|
||||
ClippingConfig
|
||||
)
|
||||
|
||||
# Define __all__ for explicit public API
|
||||
__all__ = [
|
||||
"MainConfig",
|
||||
"DataConfig",
|
||||
"FeatureConfig",
|
||||
"ModelConfig",
|
||||
"TrainingConfig",
|
||||
"CrossValidationConfig",
|
||||
"EvaluationConfig",
|
||||
"OptunaConfig",
|
||||
"WaveletTransformConfig",
|
||||
"ClippingConfig",
|
||||
]
|
@ -1,62 +1,151 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Union
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing import Optional, List, Union, Literal
|
||||
from enum import Enum
|
||||
|
||||
# --- Nested Configs ---
|
||||
|
||||
class WaveletTransformConfig(BaseModel):
|
||||
"""Configuration for optional wavelet transform features."""
|
||||
apply: bool = False
|
||||
target_or_feature: str = "target"
|
||||
target_or_feature: Literal['target', 'feature'] = "target"
|
||||
wavelet_type: str = "db4"
|
||||
level: int = 3
|
||||
level: int = Field(3, gt=0) # Level must be positive
|
||||
use_coeffs: List[str] = ["approx", "detail_1"]
|
||||
|
||||
class ClippingConfig(BaseModel):
|
||||
"""Configuration for optional feature clipping."""
|
||||
apply: bool = False
|
||||
clip_min: float = -5.0
|
||||
clip_max: float = 5.0
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_clip_range(self) -> 'ClippingConfig':
|
||||
if self.apply and self.clip_max <= self.clip_min:
|
||||
raise ValueError(f'clip_max ({self.clip_max}) must be greater than clip_min ({self.clip_min}) when clipping is applied.')
|
||||
return self
|
||||
|
||||
# --- Main Config Sections ---
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
data_path: str
|
||||
datetime_col: str
|
||||
target_col: str
|
||||
"""Configuration related to data loading and initial preparation."""
|
||||
data_path: str = Field(..., description="Path to the input CSV data file.")
|
||||
# --- Raw Data Specifics ---
|
||||
raw_datetime_col: str = Field(..., description="Name of the raw datetime column in the CSV (e.g., 'MTU (CET/CEST)')")
|
||||
raw_target_col: str = Field(..., description="Name of the raw target/price column in the CSV (e.g., 'Day-ahead Price [EUR/MWh]')")
|
||||
|
||||
raw_datetime_format: str = '%d.%m.%Y %H:%M' # Example, make it configurable if needed
|
||||
|
||||
# --- Standardized Names & Processing ---
|
||||
datetime_col: str = Field(..., description="Standardized name for the datetime index after processing (e.g., 'Timestamp')")
|
||||
target_col: str = Field(..., description="Standardized name for the target column after processing (e.g., 'Price')")
|
||||
expected_frequency: Optional[str] = Field('h', description="Expected pandas frequency string (e.g., 'h', 'D', '15min'). If null, no frequency check/setting is performed.")
|
||||
fill_initial_target_nans: bool = Field(True, description="Forward/backward fill NaNs in the target column immediately after loading?")
|
||||
|
||||
class FeatureConfig(BaseModel):
|
||||
sequence_length: int
|
||||
forecast_horizon: int
|
||||
lags: List[int]
|
||||
rolling_window_sizes: List[int]
|
||||
use_time_features: bool
|
||||
scaling_method: Optional[str] = None
|
||||
"""Configuration for feature engineering and preprocessing."""
|
||||
sequence_length: int = Field(..., gt=0)
|
||||
forecast_horizon: int = Field(..., gt=0)
|
||||
lags: List[int] = []
|
||||
rolling_window_sizes: List[int] = []
|
||||
use_time_features: bool = True
|
||||
sinus_curve: bool = False # Added
|
||||
cosin_curve: bool = False # Added
|
||||
wavelet_transform: Optional[WaveletTransformConfig] = None
|
||||
fill_nan: Optional[Union[str, float, int]] = 'ffill' # Added (e.g., 'ffill', 0)
|
||||
clipping: ClippingConfig = ClippingConfig() # Default instance
|
||||
scaling_method: Optional[Literal['standard', 'minmax']] = 'standard' # Added literal validation
|
||||
|
||||
@field_validator('lags', 'rolling_window_sizes')
|
||||
@classmethod
|
||||
def check_positive_list_values(cls, v: List[int]) -> List[int]:
|
||||
if any(val <= 0 for val in v):
|
||||
raise ValueError('Lists lags/rolling_window_sizes must contain only positive values')
|
||||
return v
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
input_size: Optional[int] = None # Will be calculated
|
||||
hidden_size: int
|
||||
num_layers: int
|
||||
dropout: float
|
||||
"""Configuration for the forecasting model architecture."""
|
||||
# input_size: Optional[int] = Field(None, gt=0) # Removed: Determined dynamically
|
||||
hidden_size: int = Field(..., gt=0)
|
||||
num_layers: int = Field(..., gt=0)
|
||||
dropout: float = Field(..., ge=0.0, le=1.0)
|
||||
use_residual_skips: bool = False
|
||||
output_size: Optional[int] = None # Will be calculated
|
||||
# Add forecast_horizon here to ensure LightningModule gets it directly
|
||||
forecast_horizon: Optional[int] = Field(None, gt=0) # Will be set from FeatureConfig
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
batch_size: int
|
||||
epochs: int
|
||||
learning_rate: float
|
||||
loss_function: str
|
||||
device: str
|
||||
early_stopping_patience: Optional[int] = None
|
||||
scheduler_step_size: Optional[int] = None
|
||||
scheduler_gamma: Optional[float] = None
|
||||
"""Configuration for the training process (PyTorch Lightning)."""
|
||||
batch_size: int = Field(..., gt=0)
|
||||
epochs: int = Field(..., gt=0) # Max epochs
|
||||
learning_rate: float = Field(..., gt=0.0)
|
||||
loss_function: Literal['MSE', 'MAE'] = 'MSE'
|
||||
# device: str = 'auto' # Handled by PL Trainer accelerator/devices args
|
||||
early_stopping_patience: Optional[int] = Field(None, ge=1) # Patience must be >= 1 if set
|
||||
scheduler_step_size: Optional[int] = Field(None, gt=0)
|
||||
scheduler_gamma: Optional[float] = Field(None, gt=0.0, lt=1.0)
|
||||
gradient_clip_val: Optional[float] = Field(None, ge=0.0) # Added
|
||||
num_workers: int = Field(0, ge=0) # Added
|
||||
precision: Literal[16, 32, 64, 'bf16'] = 32 # Added
|
||||
|
||||
class CrossValidationConfig(BaseModel):
|
||||
n_splits: int
|
||||
test_size_fraction: float
|
||||
val_size_fraction: float
|
||||
initial_train_size: Optional[Union[int, float]] = None
|
||||
"""Configuration for time series cross-validation."""
|
||||
n_splits: int = Field(..., gt=0)
|
||||
test_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the test set.")
|
||||
val_size_fraction: float = Field(..., gt=0.0, lt=1.0, description="Fraction of the fixed training window size for the validation set.")
|
||||
initial_train_size: Optional[Union[int, float]] = Field(None, gt=0.0, description="Size of the fixed training window (absolute number or fraction of total data > 0). If null, estimated automatically.")
|
||||
|
||||
class EvaluationConfig(BaseModel):
|
||||
metrics: List[str]
|
||||
eval_batch_size: int
|
||||
save_plots: bool
|
||||
plot_sample_size: int
|
||||
"""Configuration for the final evaluation process."""
|
||||
# metrics: List[str] = ['MAE', 'RMSE'] # Defined internally now
|
||||
eval_batch_size: int = Field(..., gt=0)
|
||||
save_plots: bool = True
|
||||
plot_sample_size: Optional[int] = Field(1000, gt=0) # Max points for plots
|
||||
|
||||
class OptunaConfig(BaseModel):
|
||||
"""Optional configuration for Optuna hyperparameter optimization."""
|
||||
enabled: bool = False
|
||||
n_trials: int = Field(20, gt=0)
|
||||
storage: Optional[str] = None # e.g., "sqlite:///output/hpo_results/study.db"
|
||||
direction: Literal['minimize', 'maximize'] = 'minimize'
|
||||
metric_to_optimize: str = 'val_mae_orig_scale'
|
||||
pruning: bool = True
|
||||
|
||||
# --- Top-Level Configuration Model ---
|
||||
|
||||
class MainConfig(BaseModel):
|
||||
"""Main configuration model nesting all sections."""
|
||||
project_name: str = "TimeSeriesForecasting"
|
||||
random_seed: Optional[int] = 42 # Added top-level seed
|
||||
|
||||
data: DataConfig
|
||||
features: FeatureConfig
|
||||
model: ModelConfig
|
||||
model: ModelConfig # ModelConfig no longer contains input_size
|
||||
training: TrainingConfig
|
||||
cross_validation: CrossValidationConfig
|
||||
evaluation: EvaluationConfig
|
||||
evaluation: EvaluationConfig
|
||||
optuna: Optional[OptunaConfig] = OptunaConfig() # Added optional Optuna config
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_forecast_horizon_consistency(self) -> 'MainConfig':
|
||||
# Ensure model config gets forecast horizon from features config if not set
|
||||
if self.features and self.model:
|
||||
if self.model.forecast_horizon is None:
|
||||
# If model config doesn't have it, set it from features config
|
||||
self.model.forecast_horizon = self.features.forecast_horizon
|
||||
elif self.model.forecast_horizon != self.features.forecast_horizon:
|
||||
# If both are set but differ, raise error
|
||||
raise ValueError(
|
||||
f"ModelConfig forecast_horizon ({self.model.forecast_horizon}) must match "
|
||||
f"FeatureConfig forecast_horizon ({self.features.forecast_horizon})."
|
||||
)
|
||||
# After potential setting, ensure model.forecast_horizon is actually set
|
||||
if self.model and (self.model.forecast_horizon is None or self.model.forecast_horizon <= 0):
|
||||
raise ValueError("ModelConfig requires a positive forecast_horizon (must be set in features config if not set explicitly in model config).")
|
||||
|
||||
# Input size check is removed as it's not part of static config anymore
|
||||
|
||||
return self
|
||||
|
||||
class Config:
|
||||
# Example configuration for Pydantic itself
|
||||
validate_assignment = True # Re-validate on assignment
|
||||
# extra = 'forbid' # Forbid extra fields not defined in schema
|
Reference in New Issue
Block a user