62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
from pydantic import BaseModel, Field
|
|
from typing import Optional, List, Union
|
|
from enum import Enum
|
|
|
|
class WaveletTransformConfig(BaseModel):
    """Settings for an optional wavelet transform step (disabled by default)."""

    # Master switch; defaults to off.
    apply: bool = False
    # Which series to transform. Presumably "target" or "feature" -- TODO
    # confirm against the consumer; any string validates today.
    target_or_feature: str = "target"
    # Wavelet name, e.g. "db4" (presumably passed to a wavelet library).
    wavelet_type: str = "db4"
    # Decomposition level.
    level: int = 3
    # Coefficient sets to keep. A default_factory builds a fresh list for every
    # instance instead of sharing one mutable list literal as the default.
    use_coeffs: List[str] = Field(default_factory=lambda: ["approx", "detail_1"])


class DataConfig(BaseModel):
    """Location of the raw dataset and the names of its key columns."""

    # Path to the dataset; the file format is decided by the loader, not here.
    data_path: str
    # Column holding the timestamps.
    datetime_col: str
    # Column holding the target values.
    target_col: str


class FeatureConfig(BaseModel):
    """Feature-engineering settings applied before modelling."""

    # Presumably the number of past time steps per input window -- confirm.
    sequence_length: int
    # Presumably how many steps ahead are predicted -- confirm.
    forecast_horizon: int
    # Lag offsets (integers) used to build lagged features.
    lags: List[int]
    # Window sizes (integers) used to build rolling features.
    rolling_window_sizes: List[int]
    # Toggle for time-derived features.
    use_time_features: bool
    # Scaling method name, or None; accepted names are not validated here.
    scaling_method: Union[str, None] = None
    # Nested wavelet settings; None when no wavelet section is supplied.
    wavelet_transform: Union[WaveletTransformConfig, None] = None


class ModelConfig(BaseModel):
    """Model architecture hyperparameters.

    ``input_size`` and ``output_size`` default to ``None`` because they are
    meant to be calculated elsewhere rather than written into the config.
    """

    input_size: Union[int, None] = None  # computed later, not set by hand
    hidden_size: int
    num_layers: int
    dropout: float
    # Opt-in residual/skip connections; off unless explicitly enabled.
    use_residual_skips: bool = False
    output_size: Union[int, None] = None  # computed later, not set by hand


class TrainingConfig(BaseModel):
    """Training-loop settings."""

    batch_size: int
    epochs: int
    learning_rate: float
    # Loss identifier as a free-form string; valid names are not checked here.
    loss_function: str
    # Device identifier string (presumably e.g. "cpu"/"cuda" -- confirm).
    device: str
    # Optional knobs below; None presumably means "not used" -- confirm with
    # the training loop.
    early_stopping_patience: Union[int, None] = None
    scheduler_step_size: Union[int, None] = None
    scheduler_gamma: Union[float, None] = None


class CrossValidationConfig(BaseModel):
    """Parameters controlling how the data is split for cross-validation."""

    n_splits: int
    # Split sizes expressed as fractions (floats); no range is enforced here.
    test_size_fraction: float
    val_size_fraction: float
    # Accepts an int or a float, or None. Presumably an absolute size vs. a
    # fraction of the data -- TODO confirm with the splitter.
    initial_train_size: Union[int, float, None] = None


class EvaluationConfig(BaseModel):
    """Settings for the evaluation stage."""

    # Metric names as free-form strings; not validated here.
    metrics: List[str]
    eval_batch_size: int
    save_plots: bool
    # Presumably how many samples are drawn into plots -- confirm.
    plot_sample_size: int


class MainConfig(BaseModel):
    """Top-level configuration: one required section per pipeline stage."""

    data: DataConfig
    features: FeatureConfig
    model: ModelConfig
    training: TrainingConfig
    cross_validation: CrossValidationConfig
    evaluation: EvaluationConfig