entrix_case_challange/forecasting_model/utils/helper.py

import argparse
import json
import logging
import random
from pathlib import Path
from typing import Optional, List, Dict
import numpy as np
import pandas as pd
import torch
import yaml
from forecasting_model import MainConfig

# Module-level logger (child of the root logger configured by the application)
logger = logging.getLogger(__name__)


def parse_arguments():
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run the Time Series Forecasting training pipeline using a configuration file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        default='config.yaml',
        help="Path to the YAML configuration file."
    )
    # Removed seed, debug, and output-dir arguments
    args = parser.parse_args()
    return args
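
# Illustrative usage (a sketch, not part of the module API; the script name
# below is hypothetical, only the -c/--config flag is defined above):
#
#     args = parse_arguments()
#     config = load_config(Path(args.config))
#
# invoked e.g. as: python train.py -c configs/experiment.yaml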

def load_config(config_path: Path) -> MainConfig:
    """
    Load and validate configuration from YAML file using Pydantic.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Validated MainConfig object.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If the config doesn't match the schema.
    """
    if not config_path.is_file():
        logger.error(f"Configuration file not found at: {config_path}")
        raise FileNotFoundError(f"Config file not found: {config_path}")

    logger.info(f"Loading configuration from: {config_path}")
    try:
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)
        # Validate configuration using Pydantic model
        config = MainConfig(**config_dict)
        logger.info("Configuration loaded and validated successfully.")
        return config
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {config_path}: {e}", exc_info=True)
        raise
    except Exception as e:  # Catches Pydantic validation errors too
        logger.error(f"Error validating configuration {config_path}: {e}", exc_info=True)
        raise
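
# Example config.yaml shape (illustrative only; the authoritative schema is
# the MainConfig Pydantic model, and these field names are assumptions):
#
#     random_seed: 42
#     data:
#       csv_path: data/timeseries.csv
#     training:
#       epochs: 50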

def set_seeds(seed: Optional[int] = 42) -> None:
    """
    Set random seeds for reproducibility across libraries.

    Args:
        seed: The seed value to use. If None, uses default 42.
    """
    actual_seed = seed if seed is not None else 42
    if seed is None:
        logger.warning(f"No random_seed specified in config, using default seed: {actual_seed}")
    else:
        logger.info(f"Setting random seed from config: {actual_seed}")

    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    # Ensure reproducibility for CUDA operations where possible
    if torch.cuda.is_available():
        torch.cuda.manual_seed(actual_seed)
        torch.cuda.manual_seed_all(actual_seed)  # For multi-GPU
    # These settings can slow down training but improve reproducibility:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # PyTorch Lightning seeding (optional, as we seed torch directly):
    # pl.seed_everything(seed, workers=True)  # workers=True ensures dataloader reproducibility
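
# Typical call site (a sketch; assumes MainConfig exposes a top-level
# random_seed field, as the warning message above implies):
#
#     config = load_config(Path('config.yaml'))
#     set_seeds(config.random_seed)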

def aggregate_cv_metrics(all_fold_metrics: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """
    Calculate mean and standard deviation of metrics across folds.

    Handles potential NaN values by ignoring them.

    Args:
        all_fold_metrics: A list where each element is a dictionary of
            metrics for one fold (e.g., {'MAE': v1, 'RMSE': v2}).

    Returns:
        A dictionary where keys are metric names and values are dicts
        containing 'mean' and 'std' for that metric across folds.
        Example: {'MAE': {'mean': m, 'std': s}, 'RMSE': {'mean': m2, 'std': s2}}
    """
    if not all_fold_metrics:
        logger.warning("Received empty list for metric aggregation.")
        return {}

    aggregated: Dict[str, Dict[str, float]] = {}
    # Get metric names from the first valid fold's results
    first_valid_metrics = next((m for m in all_fold_metrics if m), None)
    if not first_valid_metrics:
        logger.warning("No valid fold metrics found for aggregation.")
        return {}

    metric_names = list(first_valid_metrics.keys())
    for metric in metric_names:
        # Collect values for this metric across all folds, ignoring NaNs
        values = [
            fold_metrics.get(metric)
            for fold_metrics in all_fold_metrics
            if fold_metrics and metric in fold_metrics
        ]
        valid_values = [v for v in values if v is not None and not np.isnan(v)]
        if not valid_values:
            logger.warning(f"No valid values found for metric '{metric}' across folds.")
            mean_val = np.nan
            std_val = np.nan
        else:
            mean_val = float(np.mean(valid_values))
            std_val = float(np.std(valid_values))
            logger.debug(f"Aggregated '{metric}': Mean={mean_val:.4f}, Std={std_val:.4f} from {len(valid_values)} folds.")
        aggregated[metric] = {'mean': mean_val, 'std': std_val}
    return aggregated
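
# Worked example (values illustrative; note np.std is the population
# standard deviation, and the NaN RMSE in fold 2 is ignored):
#
#     folds = [{'MAE': 1.0, 'RMSE': 2.0}, {'MAE': 3.0, 'RMSE': float('nan')}]
#     aggregate_cv_metrics(folds)
#     # -> {'MAE': {'mean': 2.0, 'std': 1.0}, 'RMSE': {'mean': 2.0, 'std': 0.0}}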

def save_results(results: Dict, filename: Path):
    """Save dictionary results to a JSON file."""
    try:
        filename.parent.mkdir(parents=True, exist_ok=True)
        # Convert numpy types to native Python types for JSON serialization
        # (NumpyEncoder is defined below; the name is resolved at call time)
        results_serializable = json.loads(json.dumps(results, cls=NumpyEncoder))
        with open(filename, 'w') as f:
            json.dump(results_serializable, f, indent=4)
        logger.info(f"Saved results to {filename}")
    except TypeError as e:
        logger.error(f"Serialization error saving results to {filename}. Check for non-serializable types (e.g., numpy types): {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Failed to save results to {filename}: {e}", exc_info=True)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalar and array types to native Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.bool_, bool)):
            return bool(obj)
        elif pd.isna(obj):  # Handle pandas NaT or numpy NaN gracefully
            return None
        return super().default(obj)
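
# Quick round-trip sketch (the metrics dict and output path are made up):
#
#     metrics = {'MAE': np.float64(1.23), 'fold_ids': np.arange(3)}
#     save_results(metrics, Path('outputs/cv_results.json'))
#     # resulting JSON: {"MAE": 1.23, "fold_ids": [0, 1, 2]} (pretty-printed)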