Files
entrix_case_challange/main.py
2025-05-02 14:36:19 +02:00

123 lines
3.5 KiB
Python

import logging
import random
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import torch

from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
)
# Root-logger setup: timestamped records tagged with logger name and level,
# emitted at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger, per the stdlib getLogger(__name__) convention.
logger = logging.getLogger(__name__)
def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        A validated MainConfig instance.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        NotImplementedError: YAML parsing/validation is not implemented yet.
    """
    # Fail fast on a missing file instead of letting a later access on a
    # None config produce a confusing AttributeError far from the cause.
    if not config_path.is_file():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    # TODO: parse the file (e.g. yaml.safe_load) and validate the resulting
    # dict via MainConfig. Raising keeps the unfinished state explicit;
    # the previous silent `pass` returned None and broke callers downstream.
    raise NotImplementedError("load_config: YAML parsing/validation not implemented yet")
def set_seeds(seed: int = 42) -> None:
    """
    Set random seeds for reproducibility.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch
    (CPU and, when available, all CUDA devices) with the same value.

    Args:
        seed: Base seed applied to every library. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device as well; skipped cleanly on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def determine_device(config: MainConfig) -> torch.device:
    """
    Determine the device to use for training.

    Honors an explicit ``device`` attribute on the config when present
    (NOTE(review): MainConfig's schema is not visible in this file —
    confirm the field name), otherwise prefers CUDA when available and
    falls back to CPU.

    Args:
        config: Validated main configuration.

    Returns:
        The torch.device training should run on.
    """
    requested = getattr(config, "device", None)
    if requested:
        device = torch.device(requested)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.getLogger(__name__).info("Using device: %s", device)
    return device
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.

    Args:
        all_fold_metrics: One metrics dict per fold, e.g.
            ``[{"MAE": 1.2, "RMSE": 2.0}, ...]``.

    Returns:
        Mapping from metric name to ``{"mean": ..., "std": ...}``. A metric
        missing from some folds is aggregated over the folds that report it.
        Empty input yields an empty dict.
    """
    if not all_fold_metrics:
        return {}
    # Union of metric names across folds, sorted for deterministic output.
    metric_names = sorted(set().union(*(fold.keys() for fold in all_fold_metrics)))
    aggregated: Dict[str, Dict[str, float]] = {}
    for name in metric_names:
        values = np.asarray(
            [fold[name] for fold in all_fold_metrics if name in fold],
            dtype=float,
        )
        # Population std (ddof=0), np.std's default.
        aggregated[name] = {"mean": float(values.mean()), "std": float(values.std())}
    return aggregated
def main():
    """Run the cross-validation training pipeline end to end.

    Loads the config, seeds RNGs, splits the data into time-series CV
    folds, then trains and evaluates one model per fold and logs
    mean ± std of the per-fold metrics.

    NOTE(review): indentation was lost in this dump; the nesting below is
    reconstructed — per-fold statements are assumed to sit inside the CV
    loop. Confirm against the original file.
    """
    # Load configuration
    config = load_config(Path("config.yaml"))
    # Set random seeds
    set_seeds()
    # Determine device
    device = determine_device(config)
    # Load raw data
    df = load_raw_data(config.data)
    # Initialize CV splitter
    cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
    # Initialize list to store fold metrics
    all_fold_metrics = []
    # Cross-validation loop (fold numbering starts at 1 for logging)
    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split(), 1):
        logger.info(f"Starting fold {fold_num}")
        # Prepare data loaders
        train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
            df, train_idx, val_idx, test_idx,
            config.features, config.training, config.evaluation
        )
        # Update model config with input size
        config.model.input_size = input_size
        # Initialize model
        # NOTE(review): LSTMForecastModel is never imported in this file —
        # this line raises NameError at runtime; add the missing import.
        model = LSTMForecastModel(config.model).to(device)
        # Initialize loss function
        loss_fn = torch.nn.MSELoss() if config.training.loss_function == "MSE" else torch.nn.L1Loss()
        # Initialize scheduler if configured
        scheduler = None
        if config.training.scheduler_step_size is not None:
            # TODO: Initialize scheduler
            pass
        # Initialize trainer
        # NOTE(review): Trainer is also never imported in this file.
        trainer = Trainer(
            model, train_loader, val_loader, loss_fn, device,
            config.training, scheduler, target_scaler
        )
        # Train model
        trainer.train()
        # Evaluate on test set
        # NOTE(review): evaluate_fold is also never imported in this file.
        fold_metrics = evaluate_fold(
            model, test_loader, loss_fn, device,
            target_scaler, config.evaluation, fold_num
        )
        all_fold_metrics.append(fold_metrics)
        # Optional: Clear GPU memory
        if device.type == "cuda":
            torch.cuda.empty_cache()
    # Aggregate metrics
    aggregated_metrics = aggregate_cv_metrics(all_fold_metrics)
    # Log final results
    logger.info("Cross-validation results:")
    for metric, stats in aggregated_metrics.items():
        logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")


if __name__ == "__main__":
    main()