"""Entry script: cross-validated training and evaluation of the forecasting model."""
# Standard library
import logging
import random
from pathlib import Path
from typing import Any, Dict, List

# Third-party
import numpy as np
import torch

# Local
from forecasting_model.utils.config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
)
# Module-wide logging setup: INFO level, timestamped "name - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file.

    Args:
        config_path: Path to the YAML configuration file (main() passes
            ``Path("config.yaml")``).

    Returns:
        A validated MainConfig instance.

    NOTE(review): still a stub — it currently returns None, so main() will
    fail on the first ``config.*`` attribute access. Implement by parsing the
    YAML file (presumably with ``yaml.safe_load``; PyYAML is not imported in
    this file yet) and validating the result through MainConfig.
    """
    # TODO: Implement config loading
    pass
def set_seeds(seed: int = 42) -> None:
    """
    Set random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch
    (CPU and, when available, all CUDA devices) with the same value.

    Args:
        seed: Seed applied to every RNG. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # torch.manual_seed already seeds CUDA on current PyTorch versions,
        # but seed all devices explicitly for older versions / clarity.
        torch.cuda.manual_seed_all(seed)
def determine_device(config: "MainConfig") -> torch.device:
    """
    Determine the device to use for training.

    Preference order:
      1. An explicit device string on the config (``config.device``) when it
         is set and not ``"auto"`` — NOTE(review): assumes MainConfig exposes
         an optional ``device`` field; confirm against the config schema.
         ``getattr`` keeps this safe if the field does not exist.
      2. CUDA, when available.
      3. CPU fallback.

    Args:
        config: Validated main configuration.

    Returns:
        The torch.device to run training and evaluation on.
    """
    requested = getattr(config, "device", None)
    if requested not in (None, "auto"):
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.

    Folds may report different metric sets; each metric is aggregated over
    the folds that actually contain it.

    Args:
        all_fold_metrics: One metrics dict per fold, e.g.
            ``[{"MAE": 1.0, "RMSE": 2.0}, {"MAE": 3.0, "RMSE": 4.0}]``.

    Returns:
        Mapping of metric name to ``{"mean": ..., "std": ...}`` (population
        std, NumPy's default ``ddof=0``). Empty input yields ``{}``.
    """
    aggregated: Dict[str, Dict[str, float]] = {}
    if not all_fold_metrics:
        return aggregated
    # Union of metric names across folds; sorted for deterministic ordering.
    metric_names = sorted(set().union(*all_fold_metrics))
    for name in metric_names:
        values = [fold[name] for fold in all_fold_metrics if name in fold]
        aggregated[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }
    return aggregated
def main():
    """Run the cross-validated training/evaluation pipeline end to end.

    Pipeline: load config -> seed RNGs -> pick device -> load raw data ->
    split folds -> (per fold) build loaders, train, evaluate -> aggregate
    and log metrics across folds.

    NOTE(review): ``LSTMForecastModel``, ``Trainer``, and ``evaluate_fold``
    are referenced below but never imported in this file — this will raise
    NameError at runtime; add the missing imports. ``load_config``,
    ``set_seeds``, ``determine_device``, and ``aggregate_cv_metrics`` are
    also still TODO stubs at the time of this review.
    """
    # Load configuration
    # NOTE(review): load_config is a stub returning None, so every
    # config.* access below will fail until it is implemented.
    config = load_config(Path("config.yaml"))

    # Set random seeds (uses the function's default seed)
    set_seeds()

    # Determine device
    device = determine_device(config)

    # Load raw data
    df = load_raw_data(config.data)

    # Initialize CV splitter sized to the full dataset
    cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))

    # Initialize list to store fold metrics
    all_fold_metrics = []

    # Cross-validation loop (fold numbering starts at 1)
    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split(), 1):
        logger.info(f"Starting fold {fold_num}")

        # Prepare data loaders, the fitted target scaler, and the model input size
        train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
            df, train_idx, val_idx, test_idx,
            config.features, config.training, config.evaluation
        )

        # Update model config with input size
        # NOTE(review): mutates the shared config object each fold; the value
        # is presumably identical across folds — confirm.
        config.model.input_size = input_size

        # Initialize model
        # NOTE(review): LSTMForecastModel is not imported in this file.
        model = LSTMForecastModel(config.model).to(device)

        # Initialize loss function: MSE when configured as "MSE", otherwise L1
        loss_fn = torch.nn.MSELoss() if config.training.loss_function == "MSE" else torch.nn.L1Loss()

        # Initialize scheduler if configured
        scheduler = None
        if config.training.scheduler_step_size is not None:
            # TODO: Initialize scheduler
            pass

        # Initialize trainer
        # NOTE(review): Trainer is not imported in this file.
        trainer = Trainer(
            model, train_loader, val_loader, loss_fn, device,
            config.training, scheduler, target_scaler
        )

        # Train model
        trainer.train()

        # Evaluate on test set
        # NOTE(review): evaluate_fold is not imported in this file.
        fold_metrics = evaluate_fold(
            model, test_loader, loss_fn, device,
            target_scaler, config.evaluation, fold_num
        )

        all_fold_metrics.append(fold_metrics)

        # Optional: Clear GPU memory between folds
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Aggregate metrics across all folds (mean / std per metric)
    aggregated_metrics = aggregate_cv_metrics(all_fold_metrics)

    # Log final results
    logger.info("Cross-validation results:")
    for metric, stats in aggregated_metrics.items():
        logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()