import logging
from pathlib import Path
from typing import Optional  # Optional for nullable settings fields

import yaml
from pydantic import BaseModel, Field, ValidationError, field_validator  # BaseModel supports direct dict init

# --- Logger Setup ---
logger = logging.getLogger(__name__)

# --- Configuration File Path ---
# Define the default path for the configuration file
CONFIG_YAML_PATH = Path("config.yaml")

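# An illustrative config.yaml accepted by the Settings model below (field names
# and values mirror the defaults/examples defined in this module; adjust paths
# to your own project layout):
#
#   debug: false
#   log_level: INFO
#   data_file: data/energy_prices.csv
#   latex_template_file: data_analysis/utils/_latex_report_template.tex
#   output_dir: output/reports
#   zoom_start_date: "2023-01-01"
#   zoom_end_date: "2023-12-31"
#   expected_data_frequency: h
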
# --- Settings Model ---
class Settings(BaseModel):
    """
    Application settings loaded from YAML configuration.

    This class defines the configuration structure for the forecasting model,
    including data paths, logging settings, and analysis parameters.
    """
    # -- General Settings --
    debug: bool = Field(
        default=False,
        description="Enable debug mode for detailed logging and LaTeX stderr output",
        examples=[True, False]
    )
    log_level: str = Field(
        default="INFO",
        description="Logging level for the application",
        examples=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    )
    # -- IO Settings --
    data_file: Path = Field(
        default=Path("data/energy_prices.csv"),
        description="Path to the input data CSV file relative to project root",
        examples=["data/energy_prices.csv", "data/Day-ahead_Prices_60min.csv"]
    )
    latex_template_file: Optional[Path] = Field(
        default=Path("data_analysis/utils/_latex_report_template.tex"),
        description="Path to the LaTeX template file relative to project root",
        examples=["data_analysis/utils/_latex_report_template.tex", "data/byo_template.tex"]
    )
    output_dir: Path = Field(
        default=Path("output/reports"),
        description="Directory to save generated plots and report artifacts",
        examples=["output/reports", "analysis/results"]
    )
    # -- Zoom Settings (Plotting and Analysis) --
    zoom_start_date: Optional[str] = Field(
        default=None,
        description="Start date for zoomed-in analysis plots (YYYY-MM-DD format)",
        examples=["2023-01-01"]
    )
    zoom_end_date: Optional[str] = Field(
        default=None,
        description="End date for zoomed-in analysis plots (YYYY-MM-DD format)",
        examples=["2023-12-31"]
    )

    # -- Data Settings --
    expected_data_frequency: str = Field(
        default="h",
        description="Expected frequency of the time series data",
        examples=["h", "D", "M", "Y"]
    )

    @field_validator('log_level')
    @classmethod
    def validate_log_level(cls, v):
        """Validate that log_level is one of the standard logging levels."""
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            raise ValueError(f"log_level must be one of {valid_levels}")
        return v.upper()

    @field_validator('expected_data_frequency')
    @classmethod
    def validate_frequency(cls, v):
        """Validate that frequency is one of the supported pandas-style frequency strings."""
        valid_freqs = ["h", "D", "M", "Y"]
        v_lower = v.lower()  # Compare case-insensitively
        if v_lower not in [f.lower() for f in valid_freqs]:
            raise ValueError(f"expected_data_frequency must be one of {valid_freqs}")
        return v_lower  # Return the normalized lowercase value

    @field_validator('zoom_start_date', 'zoom_end_date')
    @classmethod
    def validate_date_format(cls, v):
        """Validate date format if provided."""
        if v is None:
            return v
        try:
            from datetime import datetime
            datetime.strptime(v, "%Y-%m-%d")
            return v
        except ValueError:
            raise ValueError("Date must be in YYYY-MM-DD format")

    @field_validator('latex_template_file')
    @classmethod
    def validate_latex_template_file(cls, latex_template_file):
        """Fall back to the default template path when the YAML value is null or empty."""
        return latex_template_file or cls.model_fields['latex_template_file'].default
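
    # Illustrative normalization behaviour of the validators above (a sketch,
    # not executed anywhere in this module):
    #
    #   Settings(log_level="debug", expected_data_frequency="H")
    #       -> log_level == "DEBUG", expected_data_frequency == "h"
    #   Settings(zoom_start_date="01/01/2023")
    #       -> raises ValidationError ("Date must be in YYYY-MM-DD format")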
    @classmethod
    def from_yaml(cls, yaml_path: Path) -> 'Settings':
        """
        Load settings from a YAML file.

        Args:
            yaml_path: Path to the YAML configuration file

        Returns:
            Settings instance with values from the YAML file

        Raises:
            FileNotFoundError: If the YAML file doesn't exist
            yaml.YAMLError: If the YAML file is invalid
            ValidationError: If the YAML values don't match the schema
        """
        if not yaml_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {yaml_path}")

        try:
            with open(yaml_path, 'r') as f:
                config = yaml.safe_load(f)
            # safe_load returns None for an empty file; fall back to default settings.
            return cls(**(config or {}))
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML file {yaml_path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading settings from {yaml_path}: {e}")
            raise

# --- Loading Function ---
def load_settings(config_path: Path = CONFIG_YAML_PATH) -> Settings:
    """Load settings from a YAML file, falling back to defaults if the file is missing or invalid."""
    logger.info(f"Attempting to load configuration from: {config_path.resolve()}")
    try:
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        if not config_data:
            logger.warning(f"Configuration file {config_path} is empty. Using default settings.")
            return Settings()  # Return default settings if file is empty

        settings = Settings(**config_data)
        logger.info("Configuration loaded successfully.")

        # Update the root logger level based on the loaded settings
        logging.getLogger().setLevel(settings.log_level.upper())
        logger.info(f"Log level set to: {settings.log_level.upper()}")
        logger.debug(settings.model_dump_json(indent=2))  # Log loaded settings at debug level
        return settings

    except FileNotFoundError:
        logger.warning(f"Configuration file {config_path} not found. Using default settings.")
        return Settings()  # Return default settings if file not found
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {config_path}: {e}. Using default settings.")
        return Settings()  # Return default settings on parse error
    except ValidationError as e:
        logger.error(f"Configuration validation error: {e}. Using default settings.")
        return Settings()  # Return default settings on validation error
    except Exception as e:
        logger.error(f"An unexpected error occurred while loading settings: {e}. Using default settings.")
        return Settings()  # Catch any other unexpected errors

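# Illustrative override of the default config path (the file name below is
# hypothetical, not part of this project):
#
#   settings = load_settings(Path("configs/local.yaml"))
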
# --- Global Settings Instance ---
# Load settings when the module is imported
settings = load_settings()
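

# --- Manual Check (illustrative sketch) ---
# Running this module directly dumps the effective configuration. This block is
# a convenience assumption, not required by the rest of the module.
if __name__ == "__main__":
    print(settings.model_dump_json(indent=2))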