init
This commit is contained in:
62
forecasting_model/utils/config_model.py
Normal file
62
forecasting_model/utils/config_model.py
Normal file
@ -0,0 +1,62 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Union
|
||||
from enum import Enum
|
||||
|
||||
class WaveletTransformConfig(BaseModel):
|
||||
apply: bool = False
|
||||
target_or_feature: str = "target"
|
||||
wavelet_type: str = "db4"
|
||||
level: int = 3
|
||||
use_coeffs: List[str] = ["approx", "detail_1"]
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
data_path: str
|
||||
datetime_col: str
|
||||
target_col: str
|
||||
|
||||
class FeatureConfig(BaseModel):
|
||||
sequence_length: int
|
||||
forecast_horizon: int
|
||||
lags: List[int]
|
||||
rolling_window_sizes: List[int]
|
||||
use_time_features: bool
|
||||
scaling_method: Optional[str] = None
|
||||
wavelet_transform: Optional[WaveletTransformConfig] = None
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
input_size: Optional[int] = None # Will be calculated
|
||||
hidden_size: int
|
||||
num_layers: int
|
||||
dropout: float
|
||||
use_residual_skips: bool = False
|
||||
output_size: Optional[int] = None # Will be calculated
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
batch_size: int
|
||||
epochs: int
|
||||
learning_rate: float
|
||||
loss_function: str
|
||||
device: str
|
||||
early_stopping_patience: Optional[int] = None
|
||||
scheduler_step_size: Optional[int] = None
|
||||
scheduler_gamma: Optional[float] = None
|
||||
|
||||
class CrossValidationConfig(BaseModel):
|
||||
n_splits: int
|
||||
test_size_fraction: float
|
||||
val_size_fraction: float
|
||||
initial_train_size: Optional[Union[int, float]] = None
|
||||
|
||||
class EvaluationConfig(BaseModel):
|
||||
metrics: List[str]
|
||||
eval_batch_size: int
|
||||
save_plots: bool
|
||||
plot_sample_size: int
|
||||
|
||||
class MainConfig(BaseModel):
|
||||
data: DataConfig
|
||||
features: FeatureConfig
|
||||
model: ModelConfig
|
||||
training: TrainingConfig
|
||||
cross_validation: CrossValidationConfig
|
||||
evaluation: EvaluationConfig
|
Reference in New Issue
Block a user