init
This commit is contained in:
126
main.py
Normal file
126
main.py
Normal file
@ -0,0 +1,126 @@
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from forecasting_model.utils.config_model import MainConfig
|
||||
from forecasting_model.data_processing import (
|
||||
load_raw_data,
|
||||
TimeSeriesCrossValidationSplitter,
|
||||
prepare_fold_data_and_loaders
|
||||
)
|
||||
from forecasting_model.model import LSTMForecastModel
|
||||
from forecasting_model.trainer import Trainer
|
||||
from forecasting_model.evaluation import evaluate_fold
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_config(config_path: Path) -> MainConfig:
|
||||
"""
|
||||
Load and validate configuration from YAML file.
|
||||
"""
|
||||
# TODO: Implement config loading
|
||||
pass
|
||||
|
||||
def set_seeds(seed: int = 42) -> None:
|
||||
"""
|
||||
Set random seeds for reproducibility.
|
||||
"""
|
||||
# TODO: Implement seed setting
|
||||
pass
|
||||
|
||||
def determine_device(config: MainConfig) -> torch.device:
|
||||
"""
|
||||
Determine the device to use for training.
|
||||
"""
|
||||
# TODO: Implement device determination
|
||||
pass
|
||||
|
||||
def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Calculate mean and standard deviation of metrics across folds.
|
||||
"""
|
||||
# TODO: Implement metric aggregation
|
||||
pass
|
||||
|
||||
def main():
|
||||
# Load configuration
|
||||
config = load_config(Path("config.yaml"))
|
||||
|
||||
# Set random seeds
|
||||
set_seeds()
|
||||
|
||||
# Determine device
|
||||
device = determine_device(config)
|
||||
|
||||
# Load raw data
|
||||
df = load_raw_data(config.data)
|
||||
|
||||
# Initialize CV splitter
|
||||
cv_splitter = TimeSeriesCrossValidationSplitter(config.cross_validation, len(df))
|
||||
|
||||
# Initialize list to store fold metrics
|
||||
all_fold_metrics = []
|
||||
|
||||
# Cross-validation loop
|
||||
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split(), 1):
|
||||
logger.info(f"Starting fold {fold_num}")
|
||||
|
||||
# Prepare data loaders
|
||||
train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
|
||||
df, train_idx, val_idx, test_idx,
|
||||
config.features, config.training, config.evaluation
|
||||
)
|
||||
|
||||
# Update model config with input size
|
||||
config.model.input_size = input_size
|
||||
|
||||
# Initialize model
|
||||
model = LSTMForecastModel(config.model).to(device)
|
||||
|
||||
# Initialize loss function
|
||||
loss_fn = torch.nn.MSELoss() if config.training.loss_function == "MSE" else torch.nn.L1Loss()
|
||||
|
||||
# Initialize scheduler if configured
|
||||
scheduler = None
|
||||
if config.training.scheduler_step_size is not None:
|
||||
# TODO: Initialize scheduler
|
||||
pass
|
||||
|
||||
# Initialize trainer
|
||||
trainer = Trainer(
|
||||
model, train_loader, val_loader, loss_fn, device,
|
||||
config.training, scheduler, target_scaler
|
||||
)
|
||||
|
||||
# Train model
|
||||
trainer.train()
|
||||
|
||||
# Evaluate on test set
|
||||
fold_metrics = evaluate_fold(
|
||||
model, test_loader, loss_fn, device,
|
||||
target_scaler, config.evaluation, fold_num
|
||||
)
|
||||
|
||||
all_fold_metrics.append(fold_metrics)
|
||||
|
||||
# Optional: Clear GPU memory
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Aggregate metrics
|
||||
aggregated_metrics = aggregate_cv_metrics(all_fold_metrics)
|
||||
|
||||
# Log final results
|
||||
logger.info("Cross-validation results:")
|
||||
for metric, stats in aggregated_metrics.items():
|
||||
logger.info(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user