This commit is contained in:
2025-05-02 10:45:06 +02:00
commit 7c9d809a82
29 changed files with 2931 additions and 0 deletions

View File

@ -0,0 +1,5 @@
"""
Utility functions and classes for the forecasting model.
This package contains configuration models, helper functions, and other utilities.
"""

View File

@ -0,0 +1,62 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Union
from enum import Enum
class WaveletTransformConfig(BaseModel):
apply: bool = False
target_or_feature: str = "target"
wavelet_type: str = "db4"
level: int = 3
use_coeffs: List[str] = ["approx", "detail_1"]
class DataConfig(BaseModel):
data_path: str
datetime_col: str
target_col: str
class FeatureConfig(BaseModel):
sequence_length: int
forecast_horizon: int
lags: List[int]
rolling_window_sizes: List[int]
use_time_features: bool
scaling_method: Optional[str] = None
wavelet_transform: Optional[WaveletTransformConfig] = None
class ModelConfig(BaseModel):
input_size: Optional[int] = None # Will be calculated
hidden_size: int
num_layers: int
dropout: float
use_residual_skips: bool = False
output_size: Optional[int] = None # Will be calculated
class TrainingConfig(BaseModel):
batch_size: int
epochs: int
learning_rate: float
loss_function: str
device: str
early_stopping_patience: Optional[int] = None
scheduler_step_size: Optional[int] = None
scheduler_gamma: Optional[float] = None
class CrossValidationConfig(BaseModel):
n_splits: int
test_size_fraction: float
val_size_fraction: float
initial_train_size: Optional[Union[int, float]] = None
class EvaluationConfig(BaseModel):
metrics: List[str]
eval_batch_size: int
save_plots: bool
plot_sample_size: int
class MainConfig(BaseModel):
data: DataConfig
features: FeatureConfig
model: ModelConfig
training: TrainingConfig
cross_validation: CrossValidationConfig
evaluation: EvaluationConfig