174 lines
6.2 KiB
Python
174 lines
6.2 KiB
Python
import argparse
|
|
import json
|
|
import logging
|
|
import random
|
|
from pathlib import Path
|
|
from typing import Optional, List, Dict
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
|
|
import yaml
|
|
|
|
from forecasting_model import MainConfig
|
|
|
|
# Module-scoped logger named after this module (NOT the root logger); records
# propagate up to whatever handlers the application attaches to its ancestors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the training pipeline.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse reads ``sys.argv[1:]`` exactly as before; passing an
            explicit list makes the parser usable from tests or notebooks.

    Returns:
        The parsed namespace. It exposes a single ``config`` attribute: the
        path to the YAML configuration file as a ``str`` (default
        ``'config.yaml'``).
    """
    parser = argparse.ArgumentParser(
        description="Run the Time Series Forecasting training pipeline using a configuration file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='config.yaml',
        help="Path to the YAML configuration file.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file using Pydantic.

    Args:
        config_path: Path to the YAML configuration file. A plain ``str`` path
            is also accepted and converted with ``Path()``.

    Returns:
        Validated MainConfig object.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        OSError: If the file exists but cannot be read.
        yaml.YAMLError: If the file is not valid YAML (including bytes that
            cannot be decoded).
        TypeError: If the file is empty or its top-level YAML value is not a
            mapping.
        pydantic.ValidationError: If the config doesn't match the schema.
    """
    config_path = Path(config_path)
    if not config_path.is_file():
        logger.error(f"Configuration file not found at: {config_path}")
        raise FileNotFoundError(f"Config file not found: {config_path}")

    logger.info(f"Loading configuration from: {config_path}")
    try:
        # Binary mode lets PyYAML detect the encoding (UTF-8 / UTF-16 via BOM)
        # rather than depending on the platform's locale default encoding.
        with open(config_path, 'rb') as f:
            config_dict = yaml.safe_load(f)
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
        raise
    except OSError as e:
        logger.error(f"Error reading configuration file {config_path}: {e}", exc_info=True)
        raise

    # safe_load returns None for an empty document and can return a list or a
    # scalar; MainConfig(**...) would then fail with an opaque TypeError that
    # was previously reported as a "validation" error.
    if not isinstance(config_dict, dict):
        if config_dict is None:
            msg = f"Configuration file is empty: {config_path}"
        else:
            msg = (
                f"Configuration file {config_path} must contain a YAML mapping at the "
                f"top level, got {type(config_dict).__name__}"
            )
        logger.error(msg)
        raise TypeError(msg)

    try:
        # Validate configuration using Pydantic model
        config = MainConfig(**config_dict)
    except Exception as e:  # Catches Pydantic validation errors too
        logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
        raise

    logger.info("Configuration loaded and validated successfully.")
    return config
|
|
|
|
|
|
def set_seeds(seed: Optional[int] = 42, *, deterministic: bool = False) -> None:
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch for reproducibility.

    Args:
        seed: The seed value to use. ``None`` (e.g. no ``random_seed`` in the
            config) falls back to 42 and logs a warning.
        deterministic: If True, also set ``torch.backends.cudnn.deterministic
            = True`` and ``torch.backends.cudnn.benchmark = False``. This can
            slow down training but improves reproducibility of cuDNN kernels.
            Defaults to False, which leaves the cuDNN flags untouched.
    """
    fallback_seed = 42
    if seed is None:
        actual_seed = fallback_seed
        logger.warning(f"No random_seed specified in config, using default seed: {actual_seed}")
    else:
        actual_seed = seed
        logger.info(f"Setting random seed from config: {actual_seed}")

    random.seed(actual_seed)
    np.random.seed(actual_seed)
    # torch.manual_seed seeds the RNG on all devices (CPU and every CUDA
    # device), so separate torch.cuda.manual_seed(_all) calls are redundant.
    torch.manual_seed(actual_seed)

    if deterministic:
        # Opt-in because these settings trade speed for reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def aggregate_cv_metrics(
    all_fold_metrics: List[Dict[str, float]],
    ddof: int = 0,
) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.

    Metric names are gathered from every fold, in order of first appearance,
    so a metric absent from the first fold is still aggregated. Empty/falsy
    folds are skipped, and ``None`` or NaN values are ignored.

    Args:
        all_fold_metrics: A list where each element is a dictionary of
            metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).
        ddof: Delta degrees of freedom passed to ``np.std``. The default 0
            gives the population std (the historical behaviour); use 1 for the
            sample std. If a metric has ``ddof`` or fewer valid values, its std
            is NaN instead of triggering a NumPy warning.

    Returns:
        A dictionary where keys are metric names and values are dicts
        containing 'mean' and 'std' for that metric across folds. A metric
        with no valid values gets NaN for both.
        Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
    """
    if not all_fold_metrics:
        logger.warning("Received empty list for metric aggregation.")
        return {}

    valid_folds = [fold for fold in all_fold_metrics if fold]
    if not valid_folds:
        logger.warning("No valid fold metrics found for aggregation.")
        return {}

    # dict.fromkeys de-duplicates while preserving first-seen order.
    metric_names = list(dict.fromkeys(name for fold in valid_folds for name in fold))

    aggregated: Dict[str, Dict[str, float]] = {}
    for metric in metric_names:
        # Folds lacking this metric yield None from .get() and are filtered out.
        valid_values = [
            value
            for value in (fold.get(metric) for fold in valid_folds)
            if value is not None and not np.isnan(value)
        ]

        if not valid_values:
            logger.warning(f"No valid values found for metric '{metric}' across folds.")
            mean_val = np.nan
            std_val = np.nan
        else:
            mean_val = float(np.mean(valid_values))
            std_val = (
                float(np.std(valid_values, ddof=ddof))
                if len(valid_values) > ddof
                else np.nan
            )
            logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")

        aggregated[metric] = {'mean': mean_val, 'std': std_val}

    return aggregated
|
|
|
|
|
|
def save_results(results: Dict, filename: Path):
    """
    Save dictionary results to a JSON file (indented by 4 spaces).

    Values are encoded with ``NumpyEncoder`` so NumPy scalars and arrays are
    written as native JSON types. This is best-effort: any failure is logged
    and the function returns normally without raising.

    Args:
        results: Mapping to persist.
        filename: Destination path; a ``str`` is also accepted. Missing parent
            directories are created.
    """
    # Serialize once, before touching the filesystem, so an unserializable
    # payload can never leave a truncated JSON file behind.
    try:
        payload = json.dumps(results, cls=NumpyEncoder, indent=4)
    except TypeError as e:
        logger.error(f"Serialization error saving results to {filename}. Check for non-serializable types (e.g., numpy types): {e}", exc_info=True)
        return
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
        return

    try:
        path = Path(filename)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(payload)
        logger.info(f"Saved results to {path}")
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)
|
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy / pandas values to native Python types.

    Use as ``json.dumps(obj, cls=NumpyEncoder)``. ``default`` is only called
    for objects the stock encoder cannot handle, so native ``int``, ``float``,
    ``bool`` and ``None`` never reach it. Anything not handled here falls
    through to ``JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            # NaN/inf stay floats and are emitted as NaN/Infinity
            # (JSONEncoder's allow_nan default).
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.bool_, bool)):
            return bool(obj)
        # Missing-value scalars such as pd.NaT, pd.NA or np.datetime64('NaT')
        # become null. pd.isna returns an array/Series for container inputs,
        # whose truth value is ambiguous, so only consult it for scalars.
        if pd.api.types.is_scalar(obj) and pd.isna(obj):
            return None
        return super().default(obj)
|