intermediate backup
@@ -1 +0,0 @@

BIN  data/Tau Class - Lab Report Template.zip (new file)
Binary file not shown.
@@ -1,11 +1,8 @@
 import logging
 from pathlib import Path

 import pandas as pd
-import json
-from typing import Optional, Dict, List, Any
-# Use utils for config if that's the structure
-from data_analysis.utils.data_config_model import settings
+import datetime
+from typing import Optional, Dict

 logger = logging.getLogger(__name__)

@@ -28,11 +25,9 @@ from data_analysis.io.plotting import (
     plot_zoomed_time_series,
     plot_boxplot_by_period,
     plot_histogram,
-    plot_decomposition as plot_decomposition_results, # Rename to avoid clash
+    plot_decomposition,
     plot_residuals,
     plot_acf_pacf,
     plot_seasonal_subseries,
     plot_cross_correlation,
     plot_weekly_autocorrelation
 )
 # --- Import report generator ---
@@ -41,10 +36,13 @@ from data_analysis.utils.report_model import ReportData


 # --- Modified Pipeline Function ---
-def run_eda_pipeline():
+def run_eda_pipeline(settings):
     """
     Orchestrates the Exploratory Data Analysis process using loaded settings
     and generates a LaTeX report.
+
+    Args:
+        settings: Pydantic config class
+    """
     logger.info("Starting Exploratory Data Analysis Pipeline (LaTeX Report)...")
     output_dir = settings.output_dir
@@ -88,7 +86,7 @@ def run_eda_pipeline():
     else:
         logger.warning(f"Raw price column '{PRICE_COL_RAW}' not found for initial missing value check.")

-    df, err = load_and_prepare_data(settings.data_file)
+    df, err = load_and_prepare_data(settings)
     if err or df is None:
         logger.error(f"Data loading failed: {err or 'Unknown error'}. Stopping pipeline.")
         raise SystemExit(1)
@@ -204,7 +202,7 @@ def run_eda_pipeline():
     if err: logger.error(f"Daily decomposition failed: {err}")
     elif decomp_daily:
         plot_name = "05_decomposition_daily.png"
-        err = plot_decomposition_results(decomp_daily, "Daily (Period=24)", plots_dir / plot_name)
+        err = plot_decomposition(decomp_daily, "Daily (Period=24)", plots_dir / plot_name)
         if not err: decomposition_plot_paths['daily'] = plot_name
         else: logger.warning(f"Plotting error (daily decomp): {err}")

@@ -222,7 +220,7 @@ def run_eda_pipeline():
     if err: logger.error(f"Weekly decomposition failed: {err}")
     elif decomp_weekly:
         plot_name = "07_decomposition_weekly.png"
-        err = plot_decomposition_results(decomp_weekly, "Weekly (Period=168)", plots_dir / plot_name)
+        err = plot_decomposition(decomp_weekly, "Weekly (Period=168)", plots_dir / plot_name)
         if not err: decomposition_plot_paths['weekly'] = plot_name
         else: logger.warning(f"Plotting error (weekly decomp): {err}")

@@ -366,7 +364,7 @@ def run_eda_pipeline():
     )
     try:
         generate_latex_report(
-            output_dir=output_dir,
+            settings=settings,
             df=df,
             report_data=report_data,
             series_name_stat=series_name_stat_tested,
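The hunks above replace the module-global `settings` import with an explicitly passed parameter. For orientation, a minimal sketch of the resulting call site; the pipeline's module path is an assumption (the diff does not show file names), while `load_settings` comes from the config-model hunk near the end of this commit:

```python
# Hypothetical wiring after the refactor: load settings once at the entry
# point and pass them down, instead of importing a module-level global.
from data_analysis.utils.data_config_model import load_settings
from data_analysis.pipelines.eda import run_eda_pipeline  # module path assumed

settings = load_settings()      # build the Pydantic Settings instance
run_eda_pipeline(settings)      # new signature: settings passed explicitly
```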
@@ -1,9 +1,8 @@
 import logging
 from pathlib import Path
 import pandas as pd
 from typing import Tuple, Optional, Dict, Any

-from data_analysis.utils.data_config_model import settings
+from data_analysis.utils.data_config_model import Settings

 logger = logging.getLogger(__name__)

@@ -12,13 +11,13 @@ TIME_COL_RAW = "MTU (CET/CEST)"
 PRICE_COL_RAW = "Day-ahead Price [EUR/MWh]"
 PRICE_COL = "Price" # Standardized column name after processing

-def load_and_prepare_data(file_path: Path) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
+def load_and_prepare_data(settings: Settings) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
     """
     Loads the energy price CSV data, parses the time column, sets a
     DatetimeIndex, renames columns, checks frequency, and handles missing values.

     Args:
-        file_path: Path to the input CSV file.
+        settings: Settings object; settings.data_file is the path to the input CSV file.

     Returns:
         A tuple containing:
@@ -26,16 +25,16 @@ def load_and_prepare_data(settings: Settings) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
             May include other columns if they exist in the source.
         - str | None: Error message if loading fails, otherwise None.
     """
-    logger.info(f"Attempting to load data from: {file_path.resolve()}")
+    logger.info(f"Attempting to load data from: {settings.data_file.resolve()}")
     err = None
     df = None
     try:
         # Load data, assuming header is on the first row
-        df = pd.read_csv(file_path, header=0)
+        df = pd.read_csv(settings.data_file, header=0)

         # Basic check for expected columns
         if TIME_COL_RAW not in df.columns or PRICE_COL_RAW not in df.columns:
-            err = f"Missing expected columns '{TIME_COL_RAW}' or '{PRICE_COL_RAW}' in {file_path}"
+            err = f"Missing expected columns '{TIME_COL_RAW}' or '{PRICE_COL_RAW}' in {settings.data_file}"
             logger.error(err)
             return None, err

@@ -89,7 +88,7 @@ def load_and_prepare_data(settings: Settings) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
         logger.info(f"Data loaded and prepared. Final shape: {df.shape}")

     except FileNotFoundError:
-        err = f"Data file not found: {file_path}"
+        err = f"Data file not found: {settings.data_file}"
         logger.error(err)
     except Exception as e:
         err = f"An unexpected error occurred during data loading/preparation: {e}"
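With this change the loader takes the whole `Settings` object and reads `settings.data_file` itself. A short sketch of the `(df, err)` contract; constructing `Settings` directly like this is illustrative, and the import path of the loader is assumed:

```python
from pathlib import Path

from data_analysis.utils.data_config_model import Settings
from data_analysis.io.loading import load_and_prepare_data  # module path assumed, not shown in diff

settings = Settings(data_file=Path("data/prices.csv"))  # illustrative; Settings may require more fields
df, err = load_and_prepare_data(settings)
if err or df is None:
    raise SystemExit(f"Data loading failed: {err or 'Unknown error'}")
print(df["Price"].describe())  # the loader standardizes the price column to "Price"
```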
@@ -1,15 +1,14 @@
 import logging
 from pathlib import Path
 import pandas as pd
-import numpy as np # Import numpy for CI calculation
+import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
-from typing import Optional, List
+from typing import Optional

 # Import analysis tools for plotting results
 from statsmodels.tsa.seasonal import DecomposeResult
 from statsmodels.graphics.tsaplots import plot_acf, plot_pacf, seasonal_plot
-from statsmodels.tsa.stattools import ccf # Import ccf
+from statsmodels.tsa.stattools import ccf

 logger = logging.getLogger(__name__)

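These statsmodels imports are the building blocks for the `plot_acf_pacf` helper the pipeline calls. Its body is not part of this diff; a plausible minimal sketch using only the imported API:

```python
import matplotlib.pyplot as plt
import pandas as pd
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

def plot_acf_pacf_sketch(series: pd.Series, lags: int, out_path) -> None:
    # Two stacked panels: ACF on top, PACF below.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))
    plot_acf(series.dropna(), lags=lags, ax=ax1)
    plot_pacf(series.dropna(), lags=lags, ax=ax2, method="ywm")
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
```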
@@ -9,7 +9,7 @@ import shutil

 import pandas as pd

-from data_analysis.utils.data_config_model import settings # Assuming settings are configured
+from data_analysis.utils.data_config_model import Settings
 from data_analysis.utils.report_model import ReportData

 logger = logging.getLogger(__name__)

@@ -117,13 +117,13 @@ def series_to_latex(series: Optional[pd.Series], title: Optional[str] = None, ca


 # --- Report Generation Function (LaTeX) ---
-def compile_latex_report(report_tex_path: Path, output_dir: Path) -> bool:
+def compile_latex_report(report_tex_path: Path, settings: Settings) -> bool:
     """
     Attempts to compile the LaTeX report using the local LaTeX installation.

     Args:
         report_tex_path: Path to the .tex file
-        output_dir: Directory where the PDF should be saved
+        settings: Pydantic settings class

     Returns:
         bool: True if compilation was successful, False otherwise
@@ -131,8 +131,8 @@ def compile_latex_report(report_tex_path: Path, settings: Settings) -> bool:
     logger.info(f"Attempting to compile LaTeX report: {report_tex_path}")

     # Create necessary directories
-    reports_dir = output_dir / "reports"
-    tmp_dir = output_dir / "_tmp"
+    reports_dir = settings.output_dir / "reports"
+    tmp_dir = settings.output_dir / "_tmp"
     reports_dir.mkdir(parents=True, exist_ok=True)
     tmp_dir.mkdir(parents=True, exist_ok=True)

@@ -235,7 +235,7 @@ def _format_stationarity_results(results: Optional[Dict[str, Any]], test_name: s
     return series_to_latex(series, title=escaped_title, label=f"{test_name.lower()}_results", escape=False)

 def generate_latex_report(
-    output_dir: Path,
+    settings: Settings,
     df: Optional[pd.DataFrame],
     report_data: ReportData,
     series_name_stat: Optional[str],
@@ -244,14 +244,15 @@ def generate_latex_report(
     other_plot_paths: Optional[Dict[str, str]] = None,
     decomposition_model: str = 'additive',
     acf_pacf_lags: Optional[int] = 48,
-    template_path: Path = Path("data_analysis/utils/_latex_report_template.tex")
+    template_path: Path = Path("data_analysis/utils/_latex_report_template.tex"),
+    tau_path: Path = Path("data_analysis/utils/tau-class")
 ):
     """Generates the LaTeX report (.tex file) by filling the template using macros."""
     logger.info(f"Generating LaTeX EDA report using template: {template_path.resolve()}")

-    reports_dir = output_dir / "reports"
+    reports_dir = settings.output_dir / "reports"
     source_plots_dir = reports_dir / "plots" # Define source plot dir
-    tmp_dir = output_dir / "_tmp"
+    tmp_dir = settings.output_dir / "_tmp"
     tmp_plots_dir = tmp_dir / "plots" # Define target plot dir within tmp
     reports_dir.mkdir(parents=True, exist_ok=True)
     tmp_dir.mkdir(parents=True, exist_ok=True)
@@ -259,7 +260,8 @@ def generate_latex_report(
     if tmp_plots_dir.exists():
         shutil.rmtree(tmp_plots_dir)
     tmp_plots_dir.mkdir()
-    shutil.copytree(output_dir / "plots", tmp_plots_dir, dirs_exist_ok=True)
+    shutil.copytree(settings.output_dir / "plots", tmp_plots_dir, dirs_exist_ok=True)
+    shutil.copytree(tau_path, tmp_dir / "tau-class", dirs_exist_ok=True)

     report_tex_path = tmp_dir / "eda_report.tex"

@@ -366,15 +368,29 @@ def generate_latex_report(


     # --- Generate Definitions using the new add_def ---
-    # Basic Info
-    add_def("reportDateGenerated", datetime.date.today(), formatter=lambda d: d.strftime("%Y-%m-%d"))

+    # Tau Class Preamble Definitions
+    add_def("tauJournalName", "Data Analysis Report") # Example Value
+    add_def("tauReportTitle", f"Time Series EDA Report: {settings.data_file.stem}") # Use existing logic
+    add_def("tauAuthor", "Generated Automatically") # Example Value
+    add_def("tauAffiliation", "Entrix Case Challenge") # Example Value
+    add_def("tauProfessor", "Not Applicable") # Example Value
+    add_def("tauInstitution", "Automated System") # Example Value
+    add_def("tauFootinfo", "Exploratory Data Analysis") # Example Value
+    add_def("tauLeadAuthor", f"{settings.data_file.stem} EDA") # Example Value
+    add_def("tauCourse", "Time Series Analysis") # Example Value
+    add_def("tauAbstractContent", f"This report presents an exploratory data analysis of the time series data '{settings.data_file.name}'. It covers data overview, descriptive statistics, visual patterns, decomposition, stationarity, and autocorrelation.") # Example Abstract
+    add_def("tauKeywords", f"EDA, Time Series, {settings.data_file.stem}, Autocorrelation, Stationarity") # Example Keywords
+
+    # Existing Basic Info
+    add_def("reportDateGenerated", datetime.date.today(), formatter=lambda d: d.strftime("%Y-%m-%d"), escape_if_plain=False) # Already provides format
     add_def("dataSourceDescription", f"Hourly prices from {settings.data_file.name}")
     add_def("priceVariableName", settings.data_file.stem)

     # Info from DataFrame
     if df is not None and not df.empty:
-        add_def("dateRangeStart", df.index.min().date())
-        add_def("dateRangeEnd", df.index.max().date())
+        add_def("dateRangeStart", df.index.min().date(), formatter=lambda d: d.strftime("%Y-%m-%d"), escape_if_plain=False)
+        add_def("dateRangeEnd", df.index.max().date(), formatter=lambda d: d.strftime("%Y-%m-%d"), escape_if_plain=False)
         add_def("numDataPoints", len(df))
         freq_info = "Irregular/Not Inferred"
         if isinstance(df.index, pd.DatetimeIndex):
@@ -476,13 +492,13 @@ def generate_latex_report(
         kpss_p = kpss_res.get('p-value') if kpss_res else None

         if adf_p is not None and kpss_p is not None:
-            if adf_p < 0.05 and kpss_p >= 0.05:
+            if adf_p < 0.05 <= kpss_p:
                 findings_summary = "Tests suggest the series is stationary (ADF rejects H0, KPSS fails to reject H0)."
-            elif adf_p >= 0.05 and kpss_p < 0.05:
+            elif adf_p >= 0.05 > kpss_p:
                 findings_summary = "Tests suggest the series is non-stationary (trend-stationary) and requires differencing (ADF fails to reject H0, KPSS rejects H0)."
             elif adf_p < 0.05 and kpss_p < 0.05:
                 findings_summary = "Test results conflict: ADF suggests stationarity, KPSS suggests non-stationarity. May indicate difference-stationarity."
-            else:
+            else: # adf_p >= 0.05 and kpss_p >= 0.05
                 findings_summary = "Tests suggest the series is non-stationary (unit root present) and requires differencing (Both fail to reject H0)."
         elif adf_p is not None:
             findings_summary = f"ADF test p-value: {adf_p:.4f}. Stationarity conclusion requires KPSS test."
@@ -493,7 +509,7 @@ def generate_latex_report(
             logger.warning(f"Could not generate stationarity summary: {e}")
             findings_summary = r"\textit{Error generating summary.}"

-    add_def("stationarityFindingsSummary", findings_summary)
+    add_def("stationarityFindingsSummary", findings_summary, escape_if_plain=False) # Summary contains LaTeX commands

     # Section 6 Autocorrelation
     add_def("autocorrSeriesAnalyzed", series_name_stat)
@@ -502,22 +518,6 @@ def generate_latex_report(
     add_path_def("plotPacf", acf_pacf_plot_paths, 'pacf')
     add_def("autocorrObservations", None, default=default_text, escape_if_plain=False)

-    # Section 7 Summary & Implications
-    add_def("summaryTrendCycles", None, default=default_text, escape_if_plain=False)
-    add_def("summarySeasonality", None, default=default_text, escape_if_plain=False)
-    add_def("summaryStationarity", None, default=default_text, escape_if_plain=False)
-    add_def("summaryAutocorrelations", None, default=default_text, escape_if_plain=False)
-    add_def("summaryOutliersVolatility", None, default=default_text, escape_if_plain=False)
-    add_def("implicationsModelChoice", None, default=default_text, escape_if_plain=False)
-    add_def("implicationsFeatureEngineering", None, default=default_text, escape_if_plain=False)
-    add_def("implicationsPreprocessing", None, default=default_text, escape_if_plain=False)
-    add_def("implicationsEvaluation", None, default=default_text, escape_if_plain=False)
-    add_def("implicationsProbabilistic", None, default=default_text, escape_if_plain=False)
-
-    # Section 8 Conclusion
-    add_def("conclusionStatement", None, default=default_text, escape_if_plain=False)
-

     # --- Apply Definitions to Template ---
     definitions_block = "\n".join(latex_definitions)
     if "{{LATEX_DEFINITIONS}}" not in template:
@@ -531,20 +531,8 @@ def generate_latex_report(
         f.write(report_content)
     logger.info(f"Successfully generated LaTeX report source: {report_tex_path}")

-    # --- Copy Plots ---
-    # This is now handled within add_path_def to copy files individually
-    # logger.info(f"Copying plots from {source_plots_dir} to {tmp_plots_dir}")
-    # try:
-    #     shutil.copytree(source_plots_dir, tmp_plots_dir, dirs_exist_ok=True) # dirs_exist_ok=True allows overwriting
-    # except FileNotFoundError:
-    #     logger.error(f"Source plots directory not found: {source_plots_dir}")
-    #     raise # Re-raise error if plots dir is essential
-    # except Exception as e:
-    #     logger.error(f"Failed to copy plots directory: {e}", exc_info=True)
-    #     raise # Re-raise error
-
     # Attempt to compile the report
-    if compile_latex_report(report_tex_path, output_dir):
+    if compile_latex_report(report_tex_path, settings):
         logger.info("LaTeX report successfully compiled to PDF")
     else:
         logger.warning("LaTeX compilation failed. Check logs above. The .tex file is available for manual compilation.")
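One behavioral no-op worth noting in the stationarity hunk above: `adf_p < 0.05 and kpss_p >= 0.05` becomes the chained comparison `adf_p < 0.05 <= kpss_p`, which Python evaluates as the same conjunction. A standalone sketch of the combined interpretation, with the diff's wording condensed and its 0.05 threshold kept:

```python
def interpret_stationarity(adf_p: float, kpss_p: float, alpha: float = 0.05) -> str:
    """Combine ADF (H0: unit root) and KPSS (H0: stationary) p-values."""
    if adf_p < alpha <= kpss_p:        # same as: adf_p < alpha and kpss_p >= alpha
        return "stationary (ADF rejects H0, KPSS fails to reject H0)"
    if adf_p >= alpha > kpss_p:        # same as: adf_p >= alpha and kpss_p < alpha
        return "non-stationary; differencing required (ADF fails to reject H0, KPSS rejects H0)"
    if adf_p < alpha and kpss_p < alpha:
        return "conflicting results; may indicate difference-stationarity"
    return "non-stationary, unit root present (both fail to reject H0)"
```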
@@ -1,48 +1,60 @@
-% LaTeX EDA Report Template
-\documentclass[11pt, a4paper]{article}
+% LaTeX EDA Report Template - Adapted for Tau Class
+\documentclass[9pt,a4paper,twocolumn,twoside]{tau-class/tau}
+\usepackage[english]{babel}

-% --- Packages ---
-\usepackage[utf8]{inputenc}
-\usepackage[T1]{fontenc}
-\usepackage{lmodern} % Use Latin Modern fonts
-\usepackage[margin=1in]{geometry} % Set page margins
-\usepackage{graphicx} % Required for including images
-% \graphicspath{{../reports/plots/}} % REMOVE OR COMMENT OUT THIS LINE
-\usepackage{booktabs} % For professional quality tables (\toprule, \midrule, \bottomrule)
-\usepackage{amsmath} % For math symbols and environments
-\usepackage{datetime2} % For date formatting (optional, can use simple text)
+% --- Packages (Keep necessary ones, tau includes some) ---
+% \usepackage[utf8]{inputenc} % Often not needed with modern LaTeX/tau
+% \usepackage[T1]{fontenc} % Often not needed with modern LaTeX/tau
+% \usepackage{lmodern} % Tau likely defines fonts
+% \usepackage[margin=1in]{geometry} % Tau likely handles margins, remove if needed
+\usepackage{graphicx} % Required for including images (tau might include it, but safe to keep)
+% \graphicspath{{../reports/plots/}} % Path handled by Python/compilation process
+\usepackage{booktabs} % For professional quality tables (tau might include it)
+\usepackage{amsmath} % For math symbols and environments (tau might include it)
+\usepackage{datetime2} % For date formatting (keep if \reportDateGenerated uses it)
 \usepackage{float} % For finer control over figure placement (e.g., [H] option)
-\usepackage{caption} % For customizing captions
-\usepackage{hyperref} % For clickable links (optional)
-\usepackage{sectsty} % To potentially adjust section font sizes/styles (optional)
-\usepackage{parskip} % Use vertical space between paragraphs instead of indentation
-\usepackage{ifthen} % ADD THIS PACKAGE for conditional logic
+\usepackage{caption} % For customizing captions (tau might handle this)
+\usepackage{hyperref} % For clickable links (tau likely includes this)
+\usepackage{ifthen} % Keep for conditional logic

-% --- Hyperref Setup (Optional) ---
-\hypersetup{
-    colorlinks=true,
-    linkcolor=blue,
-    filecolor=magenta,
-    urlcolor=cyan,
-    pdftitle={Time Series EDA Report},
-    pdfpagemode=FullScreen,
-}
+% --- Hyperref Setup (Optional, tau might configure this) ---
+% \hypersetup{
+%     colorlinks=true,
+%     linkcolor=blue,
+%     filecolor=magenta,
+%     urlcolor=cyan,
+%     pdftitle={Time Series EDA Report},
+%     pdfpagemode=FullScreen,
+% }

-% --- Custom LaTeX Definitions Placeholder ---
-{{LATEX_DEFINITIONS}} % Python script will insert \newcommand definitions here
-% Define boolean flags if they don't exist (e.g., for manual compilation)
-\ifdefined\ifShowZoomedTimeseries\else\newcommand{\ifShowZoomedTimeseries}{false}\fi
-\ifdefined\ifShowYearlyDecomp\else\newcommand{\ifShowYearlyDecomp}{false}\fi
+% Populated by Python script
+{{LATEX_DEFINITIONS}}

-% --- Document Information ---
-\title{Time Series Exploratory Data Analysis Report: Hourly Prices}
-\author{Generated Automatically}
-\date{\reportDateGenerated} % Use the macro defined in Python
+% --- Tau Document Information (Define macros in Python) ---
+\journalname{\tauJournalName}
+\title{\tauReportTitle}
+\author{\tauAuthor} % Can add \author[affil]{Name} structure if needed
+\affil{\tauAffiliation}
+\professor{\tauProfessor}
+\institution{\tauInstitution}
+\footinfo{\tauFootinfo}
+\theday{\reportDateGenerated} % Keep using existing date macro
+\leadauthor{\tauLeadAuthor}
+\course{\tauCourse}

 % --- Start Document ---
 \begin{document}

 \maketitle
+\thispagestyle{firststyle} % Standard for tau first page

+% --- Abstract (Using tau structure) ---
+\begin{abstract}
+\tauAbstractContent % Define this macro in Python
+\end{abstract}
+\keywords{\tauKeywords} % Define this macro in Python
+\tauabstract % Command to display the abstract/keywords

 % --- Overview Section ---
 \section*{Report Overview}
@@ -67,7 +79,7 @@ Purpose: Understand the basic structure, size, and data types of the dataset. Ch
 \subsection*{Raw Data Sample}
 % Placeholder for Table: First 5 Rows
 \tableHeadData
-\vspace{\baselineskip} % Add some vertical space
+% \vspace{\baselineskip} % Remove manual spacing, let LaTeX/tau handle it

 % Placeholder for Table: Last 5 Rows
 \tableTailData
@@ -87,13 +99,13 @@ Purpose: Summarize the central tendency, dispersion, and distribution of the pri
 \subsection*{Missing Values}
 % Placeholder for Table: Count of Missing Values
 \tableMissingCounts
-\vspace{\baselineskip}
+% \vspace{\baselineskip}

 % Placeholder for Table: Percentage of Missing Values
 \tableMissingPercentages
-\vspace{\baselineskip}
+% \vspace{\baselineskip}

-Observations on missing values: \missingValuesObservations % Add a text placeholder
+Observations on missing values: \missingValuesObservations

 % --- Section 3: Visual Exploration ---
 \section{Visual Exploration of Time Series Patterns}
@@ -102,7 +114,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H] % Use [H] from float package to place figure 'here' if possible
 \centering
 % Placeholder for Plot: Full Time Series
-\includegraphics[width=0.9\textwidth]{\plotFullTimeseries}
+\includegraphics[width=\columnwidth]{\plotFullTimeseries} % Use \columnwidth
 \caption{Full Time Series: Price vs. Time.}
 \label{fig:full_ts}
 \end{figure}
@@ -112,7 +124,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Zoomed Time Series
-\includegraphics[width=0.9\textwidth]{\plotZoomedTimeseries}
+\includegraphics[width=\columnwidth]{\plotZoomedTimeseries} % Use \columnwidth
 \caption{Zoomed Time Series (Specific Period).}
 \label{fig:zoomed_ts}
 \end{figure}
@@ -121,7 +133,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Histogram
-\includegraphics[width=0.7\textwidth]{\plotHistogram}
+\includegraphics[width=0.9\columnwidth]{\plotHistogram} % Use \columnwidth (maybe slightly less)
 \caption{Distribution of Price Values.}
 \label{fig:histogram}
 \end{figure}
@@ -131,7 +143,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Box Plots by Hour
-\includegraphics[width=0.9\textwidth]{\plotBoxplotHour}
+\includegraphics[width=\columnwidth]{\plotBoxplotHour} % Use \columnwidth
 \caption{Price Distribution by Hour of Day.}
 \label{fig:boxplot_hour}
 \end{figure}
@@ -139,7 +151,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Box Plots by Day of Week
-\includegraphics[width=0.9\textwidth]{\plotBoxplotDayofweek}
+\includegraphics[width=\columnwidth]{\plotBoxplotDayofweek} % Use \columnwidth
 \caption{Price Distribution by Day of Week.}
 \label{fig:boxplot_dayofweek}
 \end{figure}
@@ -147,7 +159,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Box Plots by Month
-\includegraphics[width=0.9\textwidth]{\plotBoxplotMonth}
+\includegraphics[width=\columnwidth]{\plotBoxplotMonth} % Use \columnwidth
 \caption{Price Distribution by Month.}
 \label{fig:boxplot_month}
 \end{figure}
@@ -155,7 +167,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Box Plots by Year
-\includegraphics[width=0.9\textwidth]{\plotBoxplotYear}
+\includegraphics[width=\columnwidth]{\plotBoxplotYear} % Use \columnwidth
 \caption{Price Distribution by Year.}
 \label{fig:boxplot_year}
 \end{figure}
@@ -166,7 +178,7 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Optional Plot: Seasonal Sub-series Daily
-\includegraphics[width=0.9\textwidth]{\plotSeasonalSubseriesDaily}
+\includegraphics[width=\columnwidth]{\plotSeasonalSubseriesDaily} % Use \columnwidth
 \caption{Seasonal Sub-series Plot (Daily Pattern).}
 \label{fig:subseries_daily}
 \end{figure}
@@ -174,13 +186,11 @@ Purpose: Visually identify overall trends, seasonality (daily, weekly, yearly),
 \begin{figure}[H]
 \centering
 % Placeholder for Optional Plot: Seasonal Sub-series Weekly
-\includegraphics[width=0.9\textwidth]{\plotSeasonalSubseriesWeekly}
+\includegraphics[width=\columnwidth]{\plotSeasonalSubseriesWeekly} % Use \columnwidth
 \caption{Seasonal Sub-series Plot (Weekly Pattern).}
 \label{fig:subseries_weekly}
 \end{figure}

 Observations on seasonal interactions: \seasonalInteractionsObservations % Placeholder

 % --- Section 4: Time Series Decomposition ---
 \section{Time Series Decomposition}
 Purpose: Separate the time series into its underlying components: Trend, Seasonality, and Residuals. Assess how well the decomposition captures the main patterns.
@@ -190,7 +200,7 @@ Method Used: \decompositionMethodDetails
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Decomposition (Daily Period)
-\includegraphics[width=0.9\textwidth]{\plotDecompositionDaily}
+\includegraphics[width=\columnwidth]{\plotDecompositionDaily} % Use \columnwidth
 \caption{Time Series Decomposition (Daily Seasonality, Period=24).}
 \label{fig:decomp_daily}
 \end{figure}
@@ -198,7 +208,7 @@ Method Used: \decompositionMethodDetails
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Decomposition (Weekly Period)
-\includegraphics[width=0.9\textwidth]{\plotDecompositionWeekly}
+\includegraphics[width=\columnwidth]{\plotDecompositionWeekly} % Use \columnwidth
 \caption{Time Series Decomposition (Weekly Seasonality, Period=168).}
 \label{fig:decomp_weekly}
 \end{figure}
@@ -210,14 +220,12 @@ Method Used: \decompositionMethodDetails
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Decomposition (Yearly Period) - Optional
-\includegraphics[width=0.9\textwidth]{\plotDecompositionYearly}
+\includegraphics[width=\columnwidth]{\plotDecompositionYearly} % Use \columnwidth
 \caption{Time Series Decomposition (Yearly Seasonality, Period=8760).}
 \label{fig:decomp_yearly}
 \end{figure}
 }{} % Empty 'else' part - include nothing if false

 Observations on decomposition: \decompositionObservations % Placeholder

 % --- Section 5: Stationarity Analysis ---
 \section{Stationarity Analysis}
 Purpose: Determine if the statistical properties (mean, variance, autocorrelation) are constant over time.
@@ -232,7 +240,7 @@ Refer to the trend component in the decomposition plots (Figures \ref{fig:decomp
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: Residuals
-\includegraphics[width=0.9\textwidth]{\plotResiduals}
+\includegraphics[width=\columnwidth]{\plotResiduals} % Use \columnwidth
 \caption{Residuals from Decomposition (used for stationarity tests).}
 \label{fig:residuals}
 \end{figure}
@@ -240,7 +248,7 @@ Refer to the trend component in the decomposition plots (Figures \ref{fig:decomp
 \subsection*{Statistical Test Results}
 % Placeholder for Table: ADF Test Results
 \tableAdfResults
-\vspace{\baselineskip}
+% \vspace{\baselineskip}

 % Placeholder for Table: KPSS Test Results
 \tableKpssResults
@@ -259,7 +267,7 @@ Lags Shown: \autocorrLagsShown
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: ACF
-\includegraphics[width=0.9\textwidth]{\plotAcf}
+\includegraphics[width=\columnwidth]{\plotAcf} % Use \columnwidth
 \caption{Autocorrelation Function (ACF).}
 \label{fig:acf}
 \end{figure}
@@ -267,40 +275,13 @@ Lags Shown: \autocorrLagsShown
 \begin{figure}[H]
 \centering
 % Placeholder for Plot: PACF
-\includegraphics[width=0.9\textwidth]{\plotPacf}
+\includegraphics[width=\columnwidth]{\plotPacf} % Use \columnwidth
 \caption{Partial Autocorrelation Function (PACF).}
 \label{fig:pacf}
 \end{figure}

-Observations: \autocorrObservations % Placeholder
+Observations: \autocorrObservations

-% --- Section 7: Summary and Implications ---
-\section{Analysis Summary and Implications for Forecasting}
-Purpose: Synthesize the findings and discuss their relevance for modeling.
-
-\subsection*{Key Findings Summary}
-\begin{itemize}
-\item \textbf{Trend \& Cycles:} \summaryTrendCycles
-\item \textbf{Seasonality:} \summarySeasonality
-\item \textbf{Stationarity:} \summaryStationarity
-\item \textbf{Autocorrelations:} \summaryAutocorrelations
-\item \textbf{Outliers/Volatility:} \summaryOutliersVolatility
-\end{itemize}
-
-\subsection*{Implications for Day-Ahead Model}
-\begin{itemize}
-\item \textbf{Model Choice:} \implicationsModelChoice
-\item \textbf{Feature Engineering:} \implicationsFeatureEngineering
-\item \textbf{Preprocessing:} \implicationsPreprocessing
-\item \textbf{Evaluation:} \implicationsEvaluation
-\item \textbf{Probabilistic Forecasting:} \implicationsProbabilistic
-\end{itemize}
-
-% --- Section 8: Conclusion ---
-\section{Conclusion}
-Purpose: Briefly summarize the EDA process.
-
-\conclusionStatement % Placeholder
-
 % --- End Document ---
 \end{document}
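On the Python side, each macro referenced by the template arrives through the `{{LATEX_DEFINITIONS}}` placeholder as a `\newcommand` line produced by `add_def`. A stripped-down sketch of that mechanism; the real `add_def` in this commit also handles `formatter`, `default`, and `escape_if_plain`:

```python
from pathlib import Path

latex_definitions: list[str] = []

def add_def_sketch(name: str, value) -> None:
    # Each report value becomes a \newcommand macro the template can reference.
    latex_definitions.append(rf"\newcommand{{\{name}}}{{{value}}}")

add_def_sketch("tauJournalName", "Data Analysis Report")
add_def_sketch("tauCourse", "Time Series Analysis")

template = Path("data_analysis/utils/_latex_report_template.tex").read_text()
report = template.replace("{{LATEX_DEFINITIONS}}", "\n".join(latex_definitions))
```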
@@ -160,7 +160,3 @@ def load_settings(config_path: Path = CONFIG_YAML_PATH) -> Settings:
     except Exception as e:
         logger.error(f"An unexpected error occurred while loading settings: {e}. Using default settings.")
         return Settings() # Catch other potential errors
-
-# --- Global Settings Instance ---
-# Load settings when the module is imported
-settings = load_settings()
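Removing the import-time `settings = load_settings()` means importing this module no longer has side effects; callers and tests construct or load `Settings` themselves. A hedged sketch of what that enables (field names are assumptions):

```python
from pathlib import Path

from data_analysis.utils.data_config_model import Settings

def test_report_dirs_follow_injected_settings(tmp_path: Path) -> None:
    # No import-time load_settings() call fires here anymore; the test
    # controls configuration completely instead of patching a global.
    settings = Settings(output_dir=tmp_path)  # field names assumed; defaults cover the rest
    assert (settings.output_dir / "reports") == tmp_path / "reports"
```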
data_analysis/utils/tau-class/README.md (new file, 144 lines)
@@ -0,0 +1,144 @@
# Tau ~ Version 2.4.4

## Description

The main features of the class are as follows.

* The class document and custom packages are available in a *single* folder.
* Stix2 font for clear text.
* Custom environments for notes and information.
* Custom colours when code is inserted for programming languages (Matlab, C, C++, and LaTeX).
* Appropriate conjunction ('y' for Spanish, 'and' for English) when two authors are included.
* This LaTeX template is compatible with external editors.

## Updates Log

### Version 1.0.0 (01/03/2024)

[1] Launch of the first edition of the tau-book class, made especially for academic articles and laboratory reports.

### Version 2.0.0 (03/03/2024)

[1] The table of contents has a new design.
[2] Figure, table and code captions have a new style.

### Version 2.1.0 (04/03/2024)

[1] All URLs have the same font format.
[2] Corrections to the author "and" were made.
[3] Package name changed from kappa.sty to tau.sty.

### Version 2.2.0 (15/03/2024)

[1] Tau-book is dressed in midnight blue for title, sections, URLs, and more.
[2] The \abscontent{} command was removed and the abstract was modified directly.
[3] The title is now centered and lines have been removed for cleaner formatting.
[4] New colors and formatting when inserting code for better appearance.

### Version 2.3.0 (08/04/2024)

[1] Class name changed from tau-book to just tau.
[2] New code for the abstract was created.
[3] The abstract font was changed from italics to normal text, keeping the keywords in italics.
[4] The taublue color was changed.
[5] A sans serif font was added for the title, abstract, sections, captions and environments.
[6] The table of contents was redesigned for better visualization.
[7] A new environment (tauenv) with a customized title was included.
[8] The appearance of the header was modified, showing the title on the right or left depending on the page.
[9] New packages were added to help TikZ perform better.
[10] The pbalance package was added to balance the last two columns of the document (optional).
[11] The style of the fancyfoot was changed by eliminating the lines that separated the information.
[12] New code was added to define the space above and below equations.

### Version 2.3.1 (10/04/2024)

[1] We say goodbye to tau.sty.
[2] Introducing the tauenvs package, which includes the defined environments.
[3] The packages that were in tau.sty were moved to the class document (tau.cls).

### Version 2.4.0 (14/05/2024)

[1] The code of the title and abstract was modified for a better adjustment.
[2] The title is now placed on the left by default; however, it can be changed in the title preferences (see the appendix for more information).
[3] Titlepos, titlefont, authorfont and professorfont now define the title format for easy modification.
[4] When 'professor' is not declared, the title space is automatically adjusted.
[5] Fixed a bug when the 'keywords' command is not declared.
[6] The word “Keywords” now begins with a capital letter.
[7] The color command taublue was changed to 'taucolor'.
[8] When code is inserted and the 'spanish' option is declared, the code caption will say “Código”.
[9] Footer information is automatically adjusted when a command is not declared.
[10] The 'ftitle' command is now 'footinfo'.
[11] The first-page footer style is not shown on odd pages.
[12] The pbalance package is disabled by default; uncomment it in 'tau.cls' if required.

### Version 2.4.1 (22/05/2024)

[1] All class files are now placed in one folder (tau-class).
[2] New command ‘journalname’ to provide more information.
[3] The environments now have a slightly new style.
[4] A new package (taubabel) was added to make translation of the document easier.
[5] A frame was added when placing code.

### Version 2.4.2 (26/07/2024)

[1] The language boolean function has been removed from taubabel.sty; the language is now set manually in main.tex to avoid confusion.
[2] The graphics path option was added in tau.cls/packages for figures.

### Version 2.4.3 (01/09/2024)

[1] The journalname font size was modified to improve the visual appearance of the document.

### Version 2.4.4 (28/02/2025)

[1] Added an arrow at line breaks in inserted code.
[2] Numbers in code are now shown in blue to differentiate them.
[3] Keywords in code are now shown in bold.
[4] The lstset for the Matlab language was removed for better integration.
[5] The tabletext command now displays its text in italics.
[6] Line numbers and the ToC are disabled by default.

## Supporting

I appreciate that you are using the tau class as your preferred template. If you would like to acknowledge this class, adding a sentence somewhere in your document such as 'this report/article was typeset using the tau class, a LaTeX template' would be great!

**More of our work**

Did you like this class document? Check out our new project, made for complex articles and reports.

https://es.overleaf.com/latex/templates/rho-class-academic-article-template/bpgjxjjqhtfy

**Any contributions are welcome!**

Coffee keeps me awake and helps me create better LaTeX templates. If you wish to support my work, you can do so through PayPal:

https://www.paypal.me/GuillermoJimeenez

## License

This work is licensed under Creative Commons CC BY 4.0.
To view a copy of CC BY 4.0 DEED, visit:

https://creativecommons.org/licenses/by/4.0/

This work consists of all files listed below as well as the products of their compilation.

```
tau/
|-- tau-class/
|   |-- tau.cls
|   |-- taubabel.sty
|   `-- tauenvs.sty
|-- main.tex
`-- tau.bib
```

## Contact me

Do you like the design, but you found a bug? Is there something you would have done differently? All comments are welcome!

*Instagram: memo.notess*
*Email: memo.notess1@gmail.com*
*Website: https://memonotess1.wixsite.com/memonotess*

---

Enjoy writing with tau class :D
data_analysis/utils/tau-class/tau.cls (new file, 476 lines)
@@ -0,0 +1,476 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% Tau
% LaTeX Class
% Version 2.4.4 (28/02/2025)
%
% Author:
% Guillermo Jimenez (memo.notess1@gmail.com)
%
% License:
% Creative Commons CC BY 4.0
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% You may modify 'tau.cls' file according to your
% needs and preferences. This LaTeX class file defines
% the document layout and behavior.
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% WARNING!
% --------------------------------------------------------
% It is important to proceed with caution and ensure that
% any modifications made are compatible with LaTeX
% syntax and conventions to avoid errors or unexpected
% behavior in the document compilation process.
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%----------------------------------------------------------
% CLASS CONFIGURATION
%----------------------------------------------------------

\NeedsTeXFormat{LaTeX2e}
\ProvidesClass{tau-class/tau}[2025/02/28 Tau LaTeX class]
\DeclareOption*{\PassOptionsToClass{\CurrentOption}{extarticle}}
\ProcessOptions\relax
\LoadClass{extarticle}
\AtEndOfClass{\RequirePackage{microtype}}

%----------------------------------------------------------
% REQUIRED PACKAGES
%----------------------------------------------------------

\RequirePackage[utf8]{inputenc}
\RequirePackage{etoolbox}
\RequirePackage[framemethod=tikz]{mdframed}
\RequirePackage{titling}
\RequirePackage{lettrine}
\RequirePackage[switch]{lineno}
\RequirePackage[bottom,hang,flushmargin,ragged]{footmisc}
\RequirePackage{fancyhdr}
\RequirePackage{xifthen}
\RequirePackage{adjustbox}
%\RequirePackage{adforn}
\RequirePackage{lastpage}
\RequirePackage[explicit]{titlesec}
\RequirePackage{booktabs}
\RequirePackage{array}
\RequirePackage{caption}
\RequirePackage{setspace}
\RequirePackage{iflang}
\RequirePackage{listings}
\RequirePackage{lipsum}
\RequirePackage{fontawesome5} % For icons
\RequirePackage{chemfig} % Chemical structures
\RequirePackage{circuitikz} % Circuits schematics
\RequirePackage{supertabular}
%\RequirePackage{matlab-prettifier}
\RequirePackage{listings}
\RequirePackage{csquotes}
\RequirePackage{ragged2e}
\RequirePackage{subcaption}
\RequirePackage{stfloats}
% \RequirePackage{pbalance}

%----------------------------------------------------------
% TAU CUSTOM PACKAGES (location/name)
%----------------------------------------------------------

\RequirePackage{tau-class/tauenvs}
\RequirePackage{tau-class/taubabel}

%----------------------------------------------------------
% PACKAGES FOR FIGURES
%----------------------------------------------------------

\usepackage{graphicx}
\RequirePackage{here}
\graphicspath{{figures/}{./}}

%----------------------------------------------------------
% PACKAGES FOR TABLES
%----------------------------------------------------------

\usepackage{adjustbox}
\RequirePackage{colortbl}
\RequirePackage{tcolorbox}

%----------------------------------------------------------
% GEOMETRY PACKAGE SETUP
%----------------------------------------------------------

\RequirePackage[
left=1.25cm,
right=1.25cm,
top=2cm,
bottom=2cm,
headsep=0.75cm
]{geometry}

%----------------------------------------------------------
% MATH PACKAGES
%----------------------------------------------------------

%!TEX In case of using another font that is not stix2 uncomment 'amssymb'

\RequirePackage{amsmath}
\RequirePackage{amsfonts}
\RequirePackage{mathtools}
% \RequirePackage{amssymb}

% Equation skip value
\newlength{\eqskip}\setlength{\eqskip}{8pt}
\expandafter\def\expandafter\normalsize\expandafter{%
\normalsize%
\setlength\abovedisplayskip{\eqskip}%
\setlength\belowdisplayskip{\eqskip}%
\setlength\abovedisplayshortskip{\eqskip-\baselineskip}%
\setlength\belowdisplayshortskip{\eqskip}%
}

%----------------------------------------------------------
% FONTS
%----------------------------------------------------------

\usepackage[notextcomp]{stix2}
\RequirePackage[scaled]{helvet}
\renewcommand{\ttdefault}{lmtt}
% \renewcommand{\ttdefault}{zi4}
% \renewcommand{\ttdefault}{txtt}
% \renewcommand{\ttdefault}{cmtl}

%----------------------------------------------------------
% URLs STYLE
%----------------------------------------------------------

\RequirePackage{url}
\RequirePackage{xurl}
\renewcommand\UrlFont{\selectfont}

%----------------------------------------------------------

\RequirePackage[colorlinks=true,allcolors=taucolor]{hyperref}
\RequirePackage{cleveref}
\RequirePackage{bookmark}

%----------------------------------------------------------
% ITEMS
%----------------------------------------------------------

\RequirePackage{enumitem}
\setlist{noitemsep}

%----------------------------------------------------------
% PREDEFINED LENGTHS
%----------------------------------------------------------

\setlength{\columnsep}{15pt}

%----------------------------------------------------------
% FIRST PAGE, HEADER AND FOOTER
%----------------------------------------------------------

% New commands
\newcommand{\footerfont}{\normalfont\sffamily\fontsize{7}{9}\selectfont}
\newcommand{\institution}[1]{\def\@institution{#1}}
\newcommand{\footinfo}[1]{\def\@footinfo{#1}}
\newcommand{\leadauthor}[1]{\def\@leadauthor{#1}}
\newcommand{\course}[1]{\def\@course{#1}}
\newcommand{\theday}[1]{\def\@theday{#1}}
\pagestyle{fancy}

% Number style
\pagenumbering{arabic} % Roman

% First page style
\fancypagestyle{firststyle}{
\fancyfoot[R]{
{\ifx\@institution\undefined\else\footerfont\@institution\hspace{10pt}\fi}
{\ifx\@theday\undefined\else\footerfont\bfseries\@theday\hspace{10pt}\fi}
{\ifx\@footinfo\undefined\else\footerfont\@footinfo\hspace{10pt}\fi}
\footerfont\textbf{\thepage\textendash\pageref{LastPage}}
}
\fancyfoot[L]{\ifx\@course\undefined\else\footerfont\@course\fi}
\fancyhead[C]{}
\fancyhead[R]{}
\fancyhead[L]{}
}

% Fancy head
\fancyhead[C]{}
\fancyhead[RE]{\footerfont\itshape\@title}
\fancyhead[RO]{}
\fancyhead[LO]{\footerfont\itshape\@title}
\fancyhead[LE]{}

% Fancy foot
\fancyfoot[C]{}
\fancyfoot[LE]{\footerfont\textbf{\thepage}\hspace{10pt}\ifx\@course\undefined\else\@course\fi}
\fancyfoot[RO]{\footerfont\ifx\@course\undefined\else\@course\hspace{10pt}\fi\textbf{\thepage}}
\fancyfoot[RE,LO]{\footerfont\ifx\@leadauthor\undefined\else\itshape\@leadauthor\fi}

% Head and foot rule
\renewcommand{\headrulewidth}{0pt} % No header rule
\renewcommand{\footrulewidth}{0pt} % No footer rule

%----------------------------------------------------------
% TAU START ~ LETTRINE
%----------------------------------------------------------

\RequirePackage{lettrine}
\newcommand{\dropcapfont}{\color{taucolor}\bfseries\fontsize{26pt}{28pt}\selectfont}
\newcommand{\taustart}[1]{\lettrine[lines=2,lraise=0,findent=2pt, nindent=0em]{{\dropcapfont{#1}}}{}}

%----------------------------------------------------------
% ABSTRACT STYLE
%----------------------------------------------------------

% Abstract font
\newcommand{\absfont}{\selectfont\small\color{taucolor}\sffamily\bfseries}
% Abstract text font
\newcommand{\abstextfont}{\selectfont\sffamily\small}

% Keywords new command
\newcommand{\keywords}[1]{\def\@keywords{#1}}
% Keyword font
\newcommand{\keywordsfont}{\selectfont\small\color{taucolor}\sffamily\bfseries}
% Keyword text font
\newcommand{\keywordsfonttext}{\selectfont\itshape\sffamily\small}

\def\xabstract{abstract}
\long\def\abstract#1\end#2{\def\two{#2}\ifx\two\xabstract
\long\gdef\theabstract{\ignorespaces#1}
\def\go{\end{abstract}}
\else
#1\end{#2}
\gdef\theabstract{\vskip12pt
\vskip12pt}
\let\go\relax\fi
\go}

\newcommand{\tauabstract}{
\noindent
\parbox{\dimexpr\linewidth}{
\vskip3pt
\hspace*{1em}{\absfont\abstractname---}\abstextfont\theabstract
}
\@ifundefined{@keywords}{\vskip15pt}{
\vskip6pt
\noindent
\parbox{\dimexpr\linewidth}{
{
\hspace*{1em}{\keywordsfont\keywordname\ignorespaces}{\keywordsfonttext\@keywords}
}
}
\vskip12pt
}
}

%----------------------------------------------------------
% TITLE STYLE
%----------------------------------------------------------

% New commands
\newcommand{\journalname}[1]{\def\@journalname{#1}}
\newcommand{\professor}[1]{\def\@professor{#1}}
\newcommand{\titlepos}{\RaggedRight}
\newcommand{\titlefont}{\bfseries\color{taucolor}\fontsize{18}{22}\sffamily\selectfont}
\newcommand{\authorfont}{\normalsize\sffamily}
\newcommand{\professorfont}{\fontsize{7pt}{8pt}\selectfont}

\renewcommand{\@maketitle}{
{\ifx\@journalname\undefined\vskip-18pt\else\vskip-25pt\RaggedRight\sffamily\bfseries\fontsize{8}{14}\@journalname\par\fi}
{\titlepos\titlefont\@title\par\vskip8pt}
{\authorfont\titlepos\@author\par\vskip8pt}
{\ifx\@professor\undefined\vskip14pt\else\professorfont\titlepos\@professor\par\vskip22pt\fi}
}

%----------------------------------------------------------
% AFFILIATION SETUP
%----------------------------------------------------------

\RequirePackage{authblk} % For custom environment affiliation

\setlength{\affilsep}{9pt}
\renewcommand\Authfont{\normalfont\sffamily\bfseries\fontsize{9}{11}}
\renewcommand\Affilfont{\normalfont\itshape\fontsize{7.5}{9}}
% \renewcommand\AB@affilsepx{; \protect\Affilfont}
\renewcommand\Authands{\ignorespaces\andlanguage}
\renewcommand\Authand{\ignorespaces\andlanguage}

%----------------------------------------------------------
% SECTION STYLE
%----------------------------------------------------------

\setcounter{secnumdepth}{5}

\titleformat{\section}
{\color{taucolor}\sffamily\large\bfseries}
{\thesection.}
{0.5em}
{#1}
[]

\titleformat{name=\section,numberless}
{\color{taucolor}\sffamily\large\bfseries}
{}
{0em}
{#1}
[]

\titleformat{\subsection}[block] % You may change it to "runin"
{\bfseries\sffamily}
{\thesubsection.}
{0.5em}
{#1} % If using runin, change #1 to '#1. ' (space at the end)
[]

\titleformat{\subsubsection}[block] % You may change it to "runin"
{\small\bfseries\sffamily\itshape}
{\thesubsubsection.}
{0.5em}
{#1} % If using runin, change #1 to '#1. ' (space at the end)
[]

\titleformat{\paragraph}[runin]
{\small\bfseries}
{}
{0em}
{#1}

\titlespacing*{\section}{0pc}{3ex \@plus4pt \@minus3pt}{5pt}
\titlespacing*{\subsection}{0pc}{2.5ex \@plus3pt \@minus2pt}{2pt}
\titlespacing*{\subsubsection}{0pc}{2ex \@plus2.5pt \@minus1.5pt}{2pt}
\titlespacing*{\paragraph}{0pc}{1.5ex \@plus2pt \@minus1pt}{12pt}

%----------------------------------------------------------
% TABLE OF CONTENTS
%----------------------------------------------------------

\newlength{\tocsep}
\setlength\tocsep{1.5pc} % Sets the indentation of the sections in the table of contents
\setcounter{tocdepth}{5} % Three levels in the table of contents section: sections, subsections and subsubsections

\RequirePackage{titletoc}
\contentsmargin{0cm}

\titlecontents{section}[\tocsep]
{\addvspace{4pt}\sffamily\selectfont\bfseries}
{\contentslabel[\thecontentslabel]{\tocsep}}
{}
{\hfill\thecontentspage}
[]

\titlecontents{subsection}[3pc]
{\addvspace{4pt}\small\sffamily\selectfont}
{\contentslabel[\thecontentslabel]{\tocsep}}
{}
{\ \titlerule*[.5pc]{.}\ \thecontentspage}
[]

\titlecontents*{subsubsection}[3pc]
{\footnotesize\sffamily\itshape\selectfont}
{}
{}
{}
[\ \textbullet\ ]

%----------------------------------------------------------
% FOOTNOTE STYLE
%----------------------------------------------------------

\definecolor{black50}{gray}{0.5}
\renewcommand*{\footnotelayout}{\normalfont\fontsize{6}{8}}
\renewcommand{\footnoterule}{
\kern -3pt
{\color{black50} \hrule width 72pt height 0.25pt}
\kern 2.5pt
}

%----------------------------------------------------------
% FIGURE-, TABLE-, LISTINGS- CAPTION STYLE
%----------------------------------------------------------

% Figure
\captionsetup[figure]{format=plain, labelsep=period, textfont={small}, justification=centering, labelfont={small,bf,sf}}

% Table
\captionsetup*[table]{labelfont={small,bf,sf},textfont={small},skip=10pt,position=below,labelsep=period} % Change it to "above" if you prefer the caption above the table.
\newcommand{\tabletext}[1]{{\setlength{\leftskip}{9pt}\fontsize{7}{9}\vskip2pt\itshape\selectfont#1}}

\captionsetup[lstlisting]{font=small, labelfont={bf,sf}, belowskip=3pt, position=below, textfont=small, labelsep=period}
\renewcommand\lstlistingname{\captioncodelanguage}
\renewcommand\lstlistlistingname{\captioncodelanguage}

%----------------------------------------------------------
% LISTINGS STYLE
%----------------------------------------------------------

% Defined colors for listings environment
\definecolor{taucodeback}{RGB}{248, 248, 248}
\definecolor{taucodecomment}{RGB}{1, 136, 0}
\definecolor{taucodekey}{RGB}{53, 53, 128}
\definecolor{taucodestring}{RGB}{122, 36, 47}
\definecolor{taugray}{RGB}{0.5,0.5,0.5}
\definecolor{tauredmatlab}{RGB}{199, 78, 0}
\definecolor{taublue}{RGB}{43, 43, 255}

% General
\lstdefinestyle{taucoding}{
backgroundcolor=\color{taucodeback},
commentstyle=\color{taucodecomment},
keywordstyle=\bfseries\color{taucodekey},
numberstyle=\tiny\color{taugray},
stringstyle=\color{taucodestring},
basicstyle=\footnotesize\ttfamily,
breakatwhitespace=false,
basicstyle=\small\ttfamily,
breaklines=true,
captionpos=b,
keepspaces=true,
numbers=left, % if "none" change the values
numbersep=8pt, % 0pt
showspaces=false,
showstringspaces=false,
showtabs=false,
tabsize=2,
aboveskip=12pt,
belowskip=8pt,
xleftmargin=12pt, % 0pt
xrightmargin=3pt,
rulecolor=\color{taugray},
frame=single,
columns=fullflexible,
postbreak=\mbox{\small\textcolor{taugray}{$\hookrightarrow$}\space},
literate=
{0}{{\textcolor{taublue}{0}}}{1}
{1}{{\textcolor{taublue}{1}}}{1}
{2}{{\textcolor{taublue}{2}}}{1}
{3}{{\textcolor{taublue}{3}}}{1}
{4}{{\textcolor{taublue}{4}}}{1}
{5}{{\textcolor{taublue}{5}}}{1}
{6}{{\textcolor{taublue}{6}}}{1}
{7}{{\textcolor{taublue}{7}}}{1}
{8}{{\textcolor{taublue}{8}}}{1}
{9}{{\textcolor{taublue}{9}}}{1}
}

\lstset{style=taucoding}

%----------------------------------------------------------
% BIBLATEX
%----------------------------------------------------------

\RequirePackage[
backend=biber,
style=ieee,
sorting=ynt
]{biblatex}

\addbibresource{tau.bib}

%----------------------------------------------------------

\endinput
84 data_analysis/utils/tau-class/taubabel.sty Normal file
@ -0,0 +1,84 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% Taubabel
% LaTeX Package
% Version 1.0.1 (26/07/2024)
%
% Author:
% Guillermo Jimenez (memo.notess1@gmail.com)
%
% License:
% Creative Commons CC BY 4.0
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% WARNING!
% --------------------------------------------------------
% Do not remove this package from 'tau.cls' to avoid
% compilation problems. This package defines the custom
% words for English and Spanish.
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%----------------------------------------------------------
% PACKAGE CONFIGURATION
%----------------------------------------------------------

\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{tau-class/taubabel}[2024/05/22]

%----------------------------------------------------------
% KEYWORDS WORD
%----------------------------------------------------------

\newcommand{\keywordname}{{Keywords---}}

%----------------------------------------------------------

%!TEX Change the second {} manually if using another babel language

%----------------------------------------------------------
% AUTHOR "AND" LANGUAGES
%----------------------------------------------------------

\newcommand{\andlanguage}{
    \iflanguage{spanish}{
        { y }%
    }{%else
        { and }%
    }%
}

%----------------------------------------------------------
% CAPTION CODE LANGUAGE
%----------------------------------------------------------

\newcommand{\captioncodelanguage}{
    \iflanguage{spanish}{
        {C\'odigo}%
    }{%else
        {Code}%
    }%
}

%----------------------------------------------------------
% ENVIRONMENTS LANGUAGES
%----------------------------------------------------------

% Information/Informaci\'on language
\newcommand{\infolanguage}{
    \iflanguage{spanish}{
        {\bfseries\noindent Informaci\'on}%
    }{%else
        {\bfseries\noindent Information}%
    }%
}

% Note/Nota language
\newcommand{\notelanguage}{
    \iflanguage{spanish}{
        {\bfseries\noindent Nota}%
    }{%else
        {\bfseries\noindent Note}%
    }%
}
106 data_analysis/utils/tau-class/tauenvs.sty Normal file
@ -0,0 +1,106 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% Tauenvs
% LaTeX Package
% Version 1.1.1 (22/05/2024)
%
% Author:
% Guillermo Jimenez (memo.notess1@gmail.com)
%
% License:
% Creative Commons CC BY 4.0
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% --------------------------------------------------------
% WARNING!
% --------------------------------------------------------
% Do not remove this package from 'tau.cls' to avoid
% compilation problems. This package defines the included
% custom environments and colors.
% --------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%----------------------------------------------------------
% PACKAGE CONFIGURATION
%----------------------------------------------------------

\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{tau-class/tauenvs}[2024/05/22]

%----------------------------------------------------------
% COLORS
%----------------------------------------------------------

\RequirePackage{xcolor}
\definecolor{taucolor}{rgb}{0.0,0.2,0.4} % Blue
% \definecolor{taucolor}{rgb}{0.12, 0.3, 0.17} % Green
% \definecolor{taucolor}{rgb}{0.4, 0.26, 0.13} % Brown

%----------------------------------------------------------
% TAUENV-, INFO-, NOTE- ENVIRONMENTS
%----------------------------------------------------------

% Tauenv
\newmdenv[
    backgroundcolor=taucolor!22,
    linecolor=taucolor!22,
    linewidth=0.7pt,
    frametitle=\vskip0pt\bfseries,
    frametitlerule=false,
    frametitlefont=\color{taucolor}\bfseries\sffamily,
    frametitlealignment=\raggedright,
    innertopmargin=3pt,
    innerbottommargin=6pt,
    innerleftmargin=6pt,
    innerrightmargin=6pt,
    font=\selectfont,
    fontcolor=taucolor,
    frametitleaboveskip=8pt,
    skipabove=10pt
]{tauenv}

%----------------------------------------------------------

% Info
\newmdenv[
    backgroundcolor=taucolor!22,
    linecolor=taucolor!22,
    linewidth=0.7pt,
    frametitle=\vskip0pt\bfseries\infolanguage,
    frametitlerule=false,
    frametitlefont=\color{taucolor}\bfseries\sffamily,
    frametitlealignment=\raggedright,
    innertopmargin=3pt,
    innerbottommargin=6pt,
    innerleftmargin=6pt,
    innerrightmargin=6pt,
    font=\normalfont,
    fontcolor=taucolor,
    frametitleaboveskip=3pt,
    skipabove=10pt
]{info}

%----------------------------------------------------------

% Note
\newmdenv[
    backgroundcolor=taucolor!22,
    linecolor=taucolor!22,
    linewidth=0.7pt,
    frametitle=\vskip0pt\bfseries\notelanguage,
    frametitlerule=false,
    frametitlefont=\color{taucolor}\bfseries\sffamily,
    frametitlealignment=\raggedright,
    innertopmargin=3pt,
    innerbottommargin=6pt,
    innerleftmargin=6pt,
    innerrightmargin=6pt,
    font=\normalfont,
    fontcolor=taucolor,
    frametitleaboveskip=3pt,
    skipabove=10pt
]{note}

%----------------------------------------------------------

\endinput
@ -5,8 +5,8 @@ from pathlib import Path
import time

# Import necessary components from your project structure
from data_analysis.utils.data_config_model import load_settings, Settings # Import loading function and model
from data_analysis.analysis.pipeline import run_eda_pipeline # Import the pipeline entry point
from data_analysis.utils.data_config_model import load_settings
from data_analysis.analysis.pipeline import run_eda_pipeline

# Silence overly verbose libraries if needed (e.g., matplotlib)
mpl_logger = logging.getLogger('matplotlib')
@ -47,13 +47,13 @@ def main():
    start_time = time.perf_counter()

    # --- Configuration Loading ---
    _ = load_settings(config_path)
    settings = load_settings(config_path)
    logger.info(f"Using configuration from: {config_path.resolve()} (or defaults if loading failed)")

    # --- Pipeline Execution ---
    try:
        # Call the main function from your pipeline module
        run_eda_pipeline()
        run_eda_pipeline(settings)

        end_time = time.perf_counter()
        logger.info(f"Main script finished successfully in {end_time - start_time:.2f} seconds.")
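The two one-line edits above carry the intent of this commit: configuration is loaded once in main() and handed to the pipeline explicitly, instead of every module importing a shared settings singleton. A minimal sketch of the resulting pattern (the config path is a placeholder, not the project's default):

# Sketch only: explicit settings injection as adopted above.
from pathlib import Path
from data_analysis.utils.data_config_model import load_settings
from data_analysis.analysis.pipeline import run_eda_pipeline

settings = load_settings(Path("config.yaml"))  # hypothetical path; yields the validated Pydantic model
run_eda_pipeline(settings)                     # consumers receive the same object explicitly

This makes the pipeline trivially testable with a hand-built settings object, where the singleton forced import-time side effects.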
@ -3,6 +3,7 @@
project_name: "TimeSeriesForecasting" # Name for the project/run
random_seed: 42 # Optional: global random seed for reproducibility
log_level: INFO # Or DEBUG
output_dir: "output" # Base directory for all outputs (logs, models, results)

# --- Execution Control ---
run_cross_validation: true # Run the main cross-validation loop?
@ -26,13 +27,14 @@ data:

# --- Feature Engineering & Preprocessing Configuration ---
features:
  sequence_length: 72 # REQUIRED: Lookback window size (e.g., 72 hours = 3 days)
  forecast_horizon: [1, 6, 12, 24] # REQUIRED: List of steps ahead to predict (e.g., 1, 6, 12, 24, 48, 72, 168 hours)
  lags: [24, 48, 72, 168] # List of lag features to create (e.g., 1 day, 2 days, 3 days, 1 week)
  rolling_window_sizes: [24, 72, 168] # List of window sizes for rolling stats (mean, std)
  sequence_length: 72 # REQUIRED: Lookback window size (e.g., 72 hours = 3 days); includes all features and lags!
  # REQUIRED: List of steps ahead to predict (e.g., 1, 6, 12, 24, 48, 72, 168 hours)
  forecast_horizon: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
  lags: [1, 2, 3, 24, 48, 168] # List of lag features to create, in hours; 168 = 1 week
  rolling_window_sizes: [72, 168] # List of window sizes for rolling stats (mean, std)
  use_time_features: true # Create calendar features (hour, dayofweek, month, etc.)?
  sinus_curve: true # Create sinusoidal feature for time of day?
  cosin_curve: true # Create cosinusoidal feature for time of day?
  cosine_curve: true # Create cosine feature for time of day?
  fill_nan: 'ffill' # Method to fill NaNs created by lags/rolling windows ('ffill', 'bfill', 0, etc.)
  scaling_method: 'standard' # Scaling method ('standard', 'minmax', or null/None for no scaling). Fit per fold.
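For orientation, a sketch of how a feature block like this is commonly realized with pandas. The column name 'price' and the function shape are illustrative assumptions, not the project's actual engineer_features implementation:

# Illustrative sketch of the lag / rolling / cyclical features configured above.
# Assumes an hourly DatetimeIndex and a 'price' target column (both assumptions).
import numpy as np
import pandas as pd

def build_features(df: pd.DataFrame, lags, rolling_windows) -> pd.DataFrame:
    out = df.copy()
    for lag in lags:  # e.g. [1, 2, 3, 24, 48, 168]
        out[f"price_lag_{lag}h"] = out["price"].shift(lag)
    for w in rolling_windows:  # e.g. [72, 168]
        out[f"price_roll_mean_{w}h"] = out["price"].rolling(w).mean()
        out[f"price_roll_std_{w}h"] = out["price"].rolling(w).std()
    # sinus_curve / cosine_curve: encode hour-of-day on the unit circle,
    # so 23:00 and 00:00 end up adjacent rather than maximally distant
    hours = out.index.hour
    out["hour_sin"] = np.sin(2 * np.pi * hours / 24)
    out["hour_cos"] = np.cos(2 * np.pi * hours / 24)
    return out.ffill()  # fill_nan: 'ffill'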
@ -62,8 +64,9 @@ model:

# --- Training Configuration (PyTorch Lightning) ---
training:
  batch_size: 64 # REQUIRED: Batch size for training
  epochs: 50 # REQUIRED: Max number of training epochs per fold
  learning_rate: 0.001 # REQUIRED: Initial learning rate for Adam optimizer
  epochs: 72 # REQUIRED: Max number of training epochs per fold
  learning_rate: 0.0001 # REQUIRED: Initial learning rate for Adam optimizer
  check_val_n_epoch: 3
  loss_function: "MSE" # Loss function ('MSE' or 'MAE')
  early_stopping_patience: 10 # Optional: Patience for early stopping (epochs). Set null/None to disable. Must be >= 1 if set.
  scheduler_step_size: null # Optional: Step size for StepLR scheduler (epochs). Set null/None to disable. Must be > 0 if set.
@ -74,7 +77,7 @@ training:

# --- Cross-Validation Configuration (Rolling Window) ---
cross_validation:
  n_splits: 5 # REQUIRED: Number of CV folds (must be > 0)
  n_splits: 3 # REQUIRED: Number of CV folds (must be > 0)
  test_size_fraction: 0.1 # REQUIRED: Fraction of the *fixed training window size* for the test set (0 < frac < 1)
  val_size_fraction: 0.1 # REQUIRED: Fraction of the *fixed training window size* for the validation set (0 < frac < 1)
  initial_train_size: null # Optional: Size of the fixed training window (integer samples or float fraction of total data > 0). If null, estimated automatically.
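A simplified sketch of the rolling-window arithmetic these keys imply. The project's TimeSeriesCrossValidationSplitter may place windows and estimate the initial train size differently, so treat this as the semantics rather than the source:

# Simplified rolling-window CV: a fixed-size train window slides forward,
# with val/test sized as fractions of that window (matching the config above).
import numpy as np

def rolling_window_splits(n_samples: int, n_splits: int,
                          val_frac: float = 0.1, test_frac: float = 0.1):
    train_size = n_samples // (n_splits + 1)  # crude stand-in for the automatic estimate
    val_size = int(train_size * val_frac)
    test_size = int(train_size * test_frac)
    step = (n_samples - train_size - val_size - test_size) // max(n_splits - 1, 1)
    for k in range(n_splits):
        start = k * step
        train_idx = np.arange(start, start + train_size)
        val_idx = np.arange(train_idx[-1] + 1, train_idx[-1] + 1 + val_size)
        test_idx = np.arange(val_idx[-1] + 1, val_idx[-1] + 1 + test_size)
        yield train_idx, val_idx, test_idx  # strictly chronological: train < val < test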
@ -88,9 +91,9 @@ evaluation:

# --- Optuna Hyperparameter Optimization Configuration ---
optuna:
  enabled: true # Set to true to actually run HPO via optuna_run.py
  study_name: "lstm_price_forecast_hpo_v1" # Specific name for this study
  n_trials: 200 # Number of trials to run
  storage: "sqlite:///output/hpo_results/study_v1.db" # Path to database file
  study_name: "lstm_price_forecast" # Specific name for this study
  n_trials: 100 # Number of trials to run
  storage: "sqlite:///study_v1.db" # Path to database file
  direction: "minimize" # 'minimize' or 'maximize'
  metric_to_optimize: "val_MeanAbsoluteError" # Metric logged in validation_step
  pruning: true # Enable pruning
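These keys map one-to-one onto Optuna's study API. A hedged sketch of how a runner script (optuna_run.py, per the comment above) would consume them; the objective here is a dummy standing in for the real training loop that returns the logged "val_MeanAbsoluteError":

# Sketch: creating a resumable Optuna study from the config block above.
import optuna

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    hidden_size = trial.suggest_int("hidden_size", 32, 256)
    # Placeholder score; the real objective trains one model and returns val_MeanAbsoluteError
    return (lr - 1e-4) ** 2 + 1.0 / hidden_size

study = optuna.create_study(
    study_name="lstm_price_forecast",
    storage="sqlite:///study_v1.db",       # persisted, so runs can resume or share workers
    direction="minimize",
    load_if_exists=True,
    pruner=optuna.pruners.MedianPruner(),  # corresponds to 'pruning: true'
)
study.optimize(objective, n_trials=100)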
@ -8,36 +8,21 @@ with support for feature engineering, cross-validation, and evaluation.
__version__ = "0.1.0"

# Expose core components for easier import
from .data_processing import (
    load_raw_data,
    engineer_features,
    TimeSeriesCrossValidationSplitter,
    prepare_fold_data_and_loaders,
    TimeSeriesDataset
)
from forecasting_model.utils.data_processing import engineer_features, prepare_fold_data_and_loaders
from forecasting_model.io.data import load_raw_data
from forecasting_model.utils.dataset_splitter import TimeSeriesCrossValidationSplitter, TimeSeriesDataset
from forecasting_model.train.model import LSTMForecastLightningModule
from .evaluation import (
    evaluate_fold_predictions,
    # Optionally expose the standalone evaluation utility if needed externally
    # evaluate_model_on_fold_test_set
)

# Expose main configuration class from utils
from .utils import MainConfig

# Expose the main execution script function if it's intended to be callable as a function
# from .forecasting_model import run # Assuming the main script is named forecasting_model.py
from forecasting_model.utils.evaluation import evaluate_fold_predictions
from forecasting_model.utils import MainConfig

# Define __all__ for explicit public API (optional but good practice)
__all__ = [
    "load_raw_data",
    "engineer_features",
    "TimeSeriesCrossValidationSplitter",
    "prepare_fold_data_and_loaders",
    "TimeSeriesDataset",
    "LSTMForecastLightningModule",
    "evaluate_fold_predictions",
    # "evaluate_model_on_fold_test_set", # Uncomment if exposed
    "MainConfig",
    # "run", # Uncomment if exposed
    "TimeSeriesDataset",
    "TimeSeriesCrossValidationSplitter",
    "load_raw_data"
]
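The net effect of the reshuffle is that the package root still re-exports the public API even though the implementations moved into utils/, io/, and train/ submodules, so downstream code keeps working unchanged:

# Consumers import from the package root; the re-exports above shield
# them from the internal module moves in this commit.
from forecasting_model import (
    MainConfig,
    load_raw_data,
    engineer_features,
    TimeSeriesCrossValidationSplitter,
)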
@ -1,377 +0,0 @@
import logging
from pathlib import Path # Added
import numpy as np
import torch
import torchmetrics
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler # For type hinting target_scaler
from typing import Dict, Optional, Union, List
import pandas as pd # For time index type hint

from forecasting_model.utils.forecast_config_model import EvaluationConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.io.plotting import (
    setup_plot_style,
    save_plot,
    create_time_series_plot,
    create_scatter_plot,
    create_residuals_plot,
    create_residuals_distribution_plot,
)


logger = logging.getLogger(__name__)


# --- Metric Calculations (Utilities - Optional) ---
# (Keep calculate_mae_np, calculate_rmse_np if needed as standalone utils)
def calculate_mae_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Mean Absolute Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated MAE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for MAE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate MAE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_absolute_error
        mae = mean_absolute_error(y_true, y_pred)
    except ImportError:
        mae = np.mean(np.abs(y_true - y_pred))
    return float(mae)


def calculate_rmse_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    [Optional Utility] Calculate Root Mean Squared Error using NumPy.
    Prefer torchmetrics inside training/validation loops.

    Args:
        y_true: Ground truth values (flattened).
        y_pred: Predicted values (flattened).

    Returns:
        Calculated RMSE, or NaN if inputs are invalid.
    """
    if y_true.shape != y_pred.shape:
        logger.error(f"Shape mismatch for RMSE: y_true={y_true.shape}, y_pred={y_pred.shape}")
        return np.nan
    if len(y_true) == 0:
        logger.warning("Attempting to calculate RMSE on empty arrays.")
        return np.nan
    try:
        # Use scikit-learn for robustness if available, otherwise basic numpy
        from sklearn.metrics import mean_squared_error
        mse = mean_squared_error(y_true, y_pred, squared=True)
    except ImportError:
        mse = np.mean((y_true - y_pred)**2)
    rmse = np.sqrt(mse)
    return float(rmse)


# --- Fold Evaluation Function ---

def evaluate_fold_predictions(
    y_true_scaled: np.ndarray,  # Shape: (n_samples, len(horizons))
    y_pred_scaled: np.ndarray,  # Shape: (n_samples, len(horizons))
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,  # Zero-based fold index
    output_dir: str,  # Base output directory
    plot_subdir: Optional[str] = "plots",
    # time_index: Optional[Union[np.ndarray, pd.Index]] = None, # OLD: Index for samples
    prediction_time_index: Optional[pd.Index] = None,  # Index corresponding to the prediction times (n_samples,)
    forecast_horizons: Optional[List[int]] = None,  # The list of horizons predicted (e.g., [1, 6, 12])
    plot_title_prefix: Optional[str] = None
) -> Dict[str, float]:
    """
    Processes prediction results (multiple horizons) for a fold or ensemble.

    Takes scaled predictions and targets (shape: samples, num_horizons),
    inverse transforms them, calculates overall metrics (MAE, RMSE) across all horizons,
    and generates evaluation plots *for the first specified horizon only*.

    Args:
        y_true_scaled: Numpy array of scaled ground truth targets (n_samples, len(horizons)).
        y_pred_scaled: Numpy array of scaled model predictions (n_samples, len(horizons)).
        target_scaler: The scaler fitted on the target variable.
        eval_config: Configuration object for evaluation parameters.
        fold_num: The current fold number (zero-based or -1 for classic).
        output_dir: The base directory to save outputs.
        plot_subdir: Specific subdirectory under output_dir for plots.
        prediction_time_index: Pandas Index representing the time for each prediction point (n_samples,).
                               Required for meaningful time plots.
        forecast_horizons: List of horizons predicted (e.g., [1, 6, 12]). Required for plotting.
        plot_title_prefix: Optional string to prepend to plot titles.

    Returns:
        Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
        original scale, calculated *across all predicted horizons*.
    """
    fold_id_str = f"Fold {fold_num + 1}" if fold_num >= 0 else "Classic Run"
    eval_context_str = f"{plot_title_prefix} {fold_id_str}" if plot_title_prefix else fold_id_str
    logger.info(f"Processing evaluation results for: {eval_context_str}")

    if y_true_scaled.shape != y_pred_scaled.shape:
        raise ValueError(f"Shape mismatch between targets and predictions for {eval_context_str}: "
                         f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
    if y_true_scaled.ndim != 2:
        raise ValueError(f"Expected 2D arrays (samples, num_horizons) for {eval_context_str}, got {y_true_scaled.ndim}D")

    n_samples, n_horizons = y_true_scaled.shape
    logger.debug(f"Processing {n_samples} samples across {n_horizons} horizons for {eval_context_str}.")

    # --- Inverse Transform (Outputs NumPy) ---
    # Flatten the multi-horizon arrays for the scaler (which expects (N, 1))
    y_true_flat_scaled = y_true_scaled.reshape(-1, 1)  # Shape: (n_samples * n_horizons, 1)
    y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)  # Shape: (n_samples * n_horizons, 1)

    y_true_inv_np: np.ndarray
    y_pred_inv_np: np.ndarray

    if target_scaler is not None:
        try:
            logger.debug(f"Inverse transforming predictions and targets for {eval_context_str}.")
            y_true_inv_flat = target_scaler.inverse_transform(y_true_flat_scaled)
            y_pred_inv_flat = target_scaler.inverse_transform(y_pred_flat_scaled)
            # Reshape back to (n_samples, n_horizons) for potential per-horizon analysis later
            y_true_inv_np = y_true_inv_flat.reshape(n_samples, n_horizons)
            y_pred_inv_np = y_pred_inv_flat.reshape(n_samples, n_horizons)
        except Exception as e:
            logger.error(f"Error during inverse scaling for {eval_context_str}: {e}", exc_info=True)
            logger.error("Metrics calculation will be skipped due to inverse transform failure.")
            return {'MAE': np.nan, 'RMSE': np.nan}
    else:
        logger.info(f"No target scaler provided for {eval_context_str}, assuming inputs are on original scale.")
        y_true_inv_np = y_true_scaled  # Keep original shape (n_samples, n_horizons)
        y_pred_inv_np = y_pred_scaled  # Keep original shape

    # --- Calculate Metrics using torchmetrics.functional (Overall across all horizons) ---
    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}
    try:
        # Flatten arrays for overall metrics calculation
        y_true_flat_for_metrics = y_true_inv_np.flatten()
        y_pred_flat_for_metrics = y_pred_inv_np.flatten()

        valid_mask = ~np.isnan(y_true_flat_for_metrics) & ~np.isnan(y_pred_flat_for_metrics)
        if np.sum(valid_mask) < len(y_true_flat_for_metrics):
            nan_count = len(y_true_flat_for_metrics) - np.sum(valid_mask)
            logger.warning(f"{nan_count} NaN values found in predictions/targets (across all horizons) for {eval_context_str}. These will be excluded from metrics.")

        if np.sum(valid_mask) > 0:
            y_true_tensor = torch.from_numpy(y_true_flat_for_metrics[valid_mask]).float().cpu()
            y_pred_tensor = torch.from_numpy(y_pred_flat_for_metrics[valid_mask]).float().cpu()

            mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
            mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
            rmse_tensor = torch.sqrt(mse_tensor)

            metrics['MAE'] = mae_tensor.item()
            metrics['RMSE'] = rmse_tensor.item()

            logger.info(f"{eval_context_str} Test Set Overall Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f} (across all horizons)")
        else:
            logger.warning(f"Skipping metric calculation for {eval_context_str} due to no valid (non-NaN) data points.")

    except Exception as e:
        logger.error(f"Failed to calculate overall metrics using torchmetrics for {eval_context_str}: {e}", exc_info=True)

    # --- Generate Plots (Optional - Focus on FIRST horizon) ---
    if eval_config.save_plots and np.sum(valid_mask) > 0:
        if forecast_horizons is None or not forecast_horizons:
            logger.warning(f"Skipping plot generation for {eval_context_str}: `forecast_horizons` list not provided.")
        elif prediction_time_index is None or len(prediction_time_index) != n_samples:
            logger.warning(f"Skipping plot generation for {eval_context_str}: `prediction_time_index` is missing or has incorrect length ({len(prediction_time_index) if prediction_time_index is not None else 'None'} != {n_samples}).")
        else:
            logger.info(f"Generating evaluation plots for {eval_context_str} (using first horizon H+{forecast_horizons[0]} only)...")
            base_plot_dir = Path(output_dir)
            fold_plot_dir = base_plot_dir / plot_subdir if plot_subdir else base_plot_dir
            setup_plot_style()

            # --- Plotting for the FIRST horizon ---
            first_horizon = forecast_horizons[0]
            y_true_h1 = y_true_inv_np[:, 0]  # Data for the first horizon
            y_pred_h1 = y_pred_inv_np[:, 0]  # Data for the first horizon
            residuals_h1 = y_true_h1 - y_pred_h1

            # Calculate the actual time index for the first horizon's targets.
            # Requires the original dataset's frequency if available, otherwise assumes simple offset.
            target_time_index_h1 = prediction_time_index
            try:
                # Assuming prediction_time_index corresponds to the *time* of prediction,
                # the target for H+h occurs `h` steps later.
                # This requires a DatetimeIndex with a frequency.
                if isinstance(prediction_time_index, pd.DatetimeIndex) and prediction_time_index.freq:
                    time_offset = pd.Timedelta(first_horizon, unit=prediction_time_index.freq.name)
                    target_time_index_h1 = prediction_time_index + time_offset
                    xlabel_h1 = f"Time (Target H+{first_horizon})"
                else:
                    logger.warning(f"Prediction time index lacks frequency info. Using original prediction time for H+{first_horizon} plot x-axis.")
                    xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"
            except Exception as time_err:
                logger.warning(f"Could not calculate target time index for H+{first_horizon}: {time_err}. Using prediction time index for x-axis.")
                xlabel_h1 = f"Prediction Time (Plotting H+{first_horizon})"

            title_suffix = f"- {eval_context_str} (H+{first_horizon})"

            try:
                fig_ts = create_time_series_plot(
                    target_time_index_h1, y_true_h1, y_pred_h1,  # Use H1 data and time
                    f"Predictions vs Actual {title_suffix}",
                    xlabel=xlabel_h1, ylabel="Value (Original Scale)",
                    max_points=eval_config.plot_sample_size
                )
                save_plot(fig_ts, fold_plot_dir / f"predictions_vs_actual_h{first_horizon}.png")

                fig_scatter = create_scatter_plot(
                    y_true_h1, y_pred_h1,  # Use H1 data
                    f"Scatter Plot {title_suffix}",
                    xlabel="Actual Values (Original Scale)", ylabel="Predicted Values (Original Scale)"
                )
                save_plot(fig_scatter, fold_plot_dir / f"scatter_predictions_h{first_horizon}.png")

                fig_res_time = create_residuals_plot(
                    target_time_index_h1, residuals_h1,  # Use H1 residuals and time
                    f"Residuals Over Time {title_suffix}",
                    xlabel=xlabel_h1, ylabel="Residual (Original Scale)",
                    max_points=eval_config.plot_sample_size
                )
                save_plot(fig_res_time, fold_plot_dir / f"residuals_time_h{first_horizon}.png")

                # Residual distribution can use residuals from ALL horizons
                residuals_all = y_true_inv_np.flatten() - y_pred_inv_np.flatten()
                fig_res_dist = create_residuals_distribution_plot(
                    residuals_all,  # Use all residuals
                    f"Residuals Distribution {eval_context_str} (All Horizons)",  # Adjusted title
                    xlabel="Residual Value (Original Scale)", ylabel="Density"
                )
                save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution_all_horizons.png")

                logger.info(f"Evaluation plots saved to: {fold_plot_dir}")

            except Exception as e:
                logger.error(f"Failed to generate or save one or more plots for {eval_context_str}: {e}", exc_info=True)

    elif eval_config.save_plots and np.sum(valid_mask) == 0:
        logger.warning(f"Skipping plot generation for {eval_context_str} due to no valid data points.")

    logger.info(f"Evaluation processing finished for {eval_context_str}.")
    return metrics


# --- (Optional) Wrapper for non-PL usage or direct testing ---
# This function still calls evaluate_fold_predictions internally, so it benefits
# from the updated plotting logic without needing direct changes here.
def evaluate_model_on_fold_test_set(
    model: LSTMForecastLightningModule,  # Use the specific type
    test_loader: DataLoader,
    device: torch.device,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,
    output_dir: str,
    # time_index: Optional[Union[np.ndarray, pd.Index]] = None, # OLD
    prediction_time_index: Optional[pd.Index] = None,  # Pass prediction time index
    forecast_horizons: Optional[List[int]] = None  # Pass horizons
) -> Dict[str, float]:
    """
    [Optional Function] Evaluates a given model on a fold's test set.
    Handles multiple forecast horizons.
    """
    logger.info(f"Starting full evaluation (inference + processing) for Fold {fold_num + 1}...")
    model.eval()
    model.to(device)

    all_preds_scaled_list: List[torch.Tensor] = []
    all_targets_scaled_list: List[torch.Tensor] = []
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            try:
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    X_batch, y_batch = batch  # y_batch shape: (batch, len(horizons))
                    targets_present = True
                else:
                    X_batch = batch
                    y_batch = None
                    targets_present = False

                X_batch = X_batch.to(device)
                outputs = model(X_batch)  # Scaled outputs: (batch, len(horizons))

                all_preds_scaled_list.append(outputs.cpu())

                if targets_present and y_batch is not None:
                    if outputs.shape != y_batch.shape:
                        raise ValueError(f"Shape mismatch: Output {outputs.shape}, Target {y_batch.shape}")
                    all_targets_scaled_list.append(y_batch.cpu())
                # ... error/warning if targets expected but not found ...

            except Exception as e:
                logger.error(f"Error during inference batch {i} for Fold {fold_num+1}: {e}", exc_info=True)
                raise ValueError(f"Inference failed on batch {i} for Fold {fold_num+1}")

    # --- Concatenate results ---
    try:
        if not all_preds_scaled_list:
            # ... handle no predictions ...
            return {'MAE': np.nan, 'RMSE': np.nan}
        # Resulting shapes: (n_samples, len(horizons))
        y_pred_scaled = torch.cat(all_preds_scaled_list, dim=0).numpy()

        y_true_scaled = None
        if all_targets_scaled_list:
            y_true_scaled = torch.cat(all_targets_scaled_list, dim=0).numpy()
        elif targets_present:
            # ... handle missing targets ...
            return {'MAE': np.nan, 'RMSE': np.nan}
        else:
            # ... handle no targets available ...
            return {'MAE': np.nan, 'RMSE': np.nan}

    except Exception as e:
        # ... error handling ...
        raise ValueError("Failed to combine batch results during evaluation inference.")

    if y_true_scaled is None:
        # ... handle missing targets ...
        return {'MAE': np.nan, 'RMSE': np.nan}

    # Ensure forecast_horizons are passed if available from the model;
    # retrieve them from the model's hparams if not passed explicitly
    if forecast_horizons is None:
        try:
            # Assuming forecast_horizon list is stored in model_config hparam
            forecast_horizons = model.hparams.model_config.forecast_horizon
        except AttributeError:
            logger.warning("Could not retrieve forecast_horizons from model hparams for evaluation.")

    # Process the collected predictions
    return evaluate_fold_predictions(
        y_true_scaled=y_true_scaled,
        y_pred_scaled=y_pred_scaled,
        target_scaler=target_scaler,
        eval_config=eval_config,
        fold_num=fold_num,
        output_dir=output_dir,
        # time_index=time_index # OLD
        prediction_time_index=prediction_time_index,  # Pass through
        forecast_horizons=forecast_horizons,  # Pass through
        plot_title_prefix=f"Test Fold {fold_num + 1}"  # Example prefix
    )
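A note on the flatten/reshape round trip used throughout the module above: scikit-learn scalers operate on (N, 1) column vectors, while multi-horizon predictions arrive as (n_samples, n_horizons). A self-contained illustration with synthetic data:

# Why the reshape round trip is needed: sklearn scalers expect (N, 1),
# multi-horizon predictions are (n_samples, n_horizons).
import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.random.randn(500, 1))  # fitted on the 1-column target
preds_scaled = np.random.randn(100, 24)                 # (n_samples, n_horizons)

flat = preds_scaled.reshape(-1, 1)                      # (2400, 1), the shape the scaler wants
preds_original = scaler.inverse_transform(flat).reshape(preds_scaled.shape)
assert preds_original.shape == (100, 24)                # back to per-horizon layout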
162 forecasting_model/io/data.py Normal file
@ -0,0 +1,162 @@
import logging
import pandas as pd

from forecasting_model.utils import DataConfig


logger = logging.getLogger(__name__)


def load_raw_data(config: DataConfig) -> pd.DataFrame:
    """
    Load raw time series data from a CSV file, handling specific formats,
    performing initial cleaning, frequency checks, and NaN filling based on config.

    Args:
        config: DataConfig object containing file path, raw/standard column names,
                frequency settings, and NaN handling flags.

    Returns:
        DataFrame with a standardized datetime index (named config.datetime_col)
        and a standardized, cleaned target column (named config.target_col).

    Raises:
        FileNotFoundError: If the data path does not exist.
        ValueError: If specified raw columns are not found, datetime parsing fails,
                    or frequency checks indicate critical issues.
        Exception: For other pandas read_csv or processing errors.
    """
    logger.info(f"Loading raw data from: {config.data_path}")
    try:
        # --- Initial Load ---
        df = pd.read_csv(config.data_path, header=0)
        logger.debug(f"Loaded raw data shape: {df.shape}")

        # --- Validate Raw Columns ---
        if config.raw_datetime_col not in df.columns:
            raise ValueError(f"Raw datetime column '{config.raw_datetime_col}' not found in {config.data_path}")
        if config.raw_target_col not in df.columns:
            raise ValueError(f"Raw target column '{config.raw_target_col}' not found in {config.data_path}")

        # --- Time Parsing (Specific Format Handling) ---
        logger.info(f"Parsing raw datetime column: '{config.raw_datetime_col}'")
        try:
            # Extract the start time part 'dd.mm.yyyy hh:mm';
            # handle potential errors during split if the format deviates
            start_times = df[config.raw_datetime_col].astype(str).str.split(' - ', expand=True)[0]
            # Define the specific format
            datetime_format = config.raw_datetime_format or '%d.%m.%Y %H:%M'
            # Parse to datetime, coercing errors to NaT
            parsed_timestamps = pd.to_datetime(start_times, format=datetime_format, errors='coerce')
        except Exception as e:
            logger.error(f"Failed to split or parse raw datetime column '{config.raw_datetime_col}' using expected format: {e}", exc_info=True)
            raise ValueError("Datetime parsing failed. Check raw_datetime_col format and data.")

        # Check for parsing errors (NaT values)
        num_parsing_errors = parsed_timestamps.isnull().sum()
        if num_parsing_errors > 0:
            original_len = len(df)
            df = df.loc[parsed_timestamps.notnull()].copy()  # Keep only rows with valid timestamps
            parsed_timestamps = parsed_timestamps.dropna()
            logger.warning(f"Dropped {num_parsing_errors} rows ({num_parsing_errors/original_len:.1%}) due to timestamp parsing errors "
                           f"(expected format: '{datetime_format}' on start time).")
            if df.empty:
                raise ValueError("No valid timestamps found after parsing. Check data format.")

        # Assign parsed timestamp and set as index with standardized name
        df[config.datetime_col] = parsed_timestamps
        df = df.set_index(config.datetime_col)
        logger.debug(f"Set '{config.datetime_col}' as index.")

        # --- Target Column Processing ---
        logger.info(f"Processing target column: '{config.raw_target_col}' -> '{config.target_col}'")
        # Convert raw target to numeric, coercing errors
        df[config.target_col] = pd.to_numeric(df[config.raw_target_col], errors='coerce')

        # Handle NaNs caused by coercion
        num_coercion_errors = df[config.target_col].isnull().sum()
        if num_coercion_errors > 0:
            logger.warning(f"Found {num_coercion_errors} non-numeric values in raw target column '{config.raw_target_col}'. Coerced to NaN.")
            # Keep rows with NaN for now, handle based on config flag below

        # --- Column Selection ---
        # Keep only the standardized target column for the forecasting pipeline;
        # discard raw columns and any others loaded initially
        df = df[[config.target_col]]
        logger.debug(f"Selected target column '{config.target_col}'. Shape: {df.shape}")

        # --- Initial Target NaN Filling (Optional) ---
        if config.fill_initial_target_nans:
            missing_prices = df[config.target_col].isnull().sum()
            if missing_prices > 0:
                logger.info(f"Found {missing_prices} missing values in target column '{config.target_col}'. Applying ffill then bfill.")
                df[config.target_col] = df[config.target_col].ffill()
                df[config.target_col] = df[config.target_col].bfill()  # Fill remaining NaNs at the start

                final_missing = df[config.target_col].isnull().sum()
                if final_missing > 0:
                    logger.error(f"{final_missing} missing values REMAIN in target column after ffill/bfill. Cannot proceed.")
                    raise ValueError("Target column contains unfillable NaN values.")
            else:
                logger.debug("No missing values found in target column.")
        else:
            logger.info("Skipping initial NaN filling for target column as per config.")
            # Warn if NaNs exist and aren't being filled here
            if df[config.target_col].isnull().any():
                logger.warning(f"NaNs exist in target column '{config.target_col}' and initial filling is disabled.")

        # --- Frequency Check & Setting ---
        logger.info("Checking time index frequency...")
        df = df.sort_index()  # Ensure index is sorted before frequency checks

        # Handle duplicate timestamps before frequency inference
        duplicates = df.index.duplicated().sum()
        if duplicates > 0:
            logger.warning(f"Found {duplicates} duplicate timestamps. Keeping the first occurrence.")
            df = df[~df.index.duplicated(keep='first')]

        if config.expected_frequency:
            inferred_freq = pd.infer_freq(df.index)
            logger.debug(f"Inferred frequency: {inferred_freq}")

            if inferred_freq == config.expected_frequency:
                logger.info(f"Inferred frequency matches expected ('{config.expected_frequency}'). Setting index frequency.")
                df = df.asfreq(config.expected_frequency)
                # Check for NaNs introduced by asfreq (filling gaps)
                missing_after_asfreq = df[config.target_col].isnull().sum()
                if missing_after_asfreq > 0:
                    logger.warning(f"{missing_after_asfreq} NaNs appeared after setting frequency to '{config.expected_frequency}'. Applying ffill/bfill.")
                    # Only fill if initial filling was also enabled; otherwise just warn. Be explicit.
                    if config.fill_initial_target_nans:
                        df[config.target_col] = df[config.target_col].ffill().bfill()
                        if df[config.target_col].isnull().any():
                            logger.error("NaNs still present after attempting to fill gaps from asfreq. Check data continuity.")
                            raise ValueError("Unfillable NaNs after setting frequency.")
                    else:
                        logger.warning("Initial NaN filling was disabled, leaving NaNs introduced by asfreq.")

            elif inferred_freq:
                logger.warning(f"Inferred frequency ('{inferred_freq}') does NOT match expected ('{config.expected_frequency}'). Index frequency will not be explicitly set. This might affect time-based features or models assuming regular intervals.")
                # Consider raising an error depending on strictness needed
                # raise ValueError("Inferred frequency does not match expected frequency.")
            else:
                logger.error(f"Could not infer frequency, but expected frequency was set to '{config.expected_frequency}'. Check data for gaps or irregularities. Index frequency will not be explicitly set.")
                # This is often a critical issue for time series models
                raise ValueError("Could not infer frequency. Ensure data has regular intervals matching expected_frequency.")
        else:
            logger.info("No expected frequency specified in config. Skipping frequency check and setting.")

        logger.info(f"Data loading and initial preparation complete. Final shape: {df.shape}")
        return df

    except FileNotFoundError:
        logger.error(f"Data file not found at: {config.data_path}")
        raise
    except ValueError as e:  # Catch ValueErrors raised internally or by pandas
        logger.error(f"Data processing error: {e}", exc_info=True)
        raise
    except Exception as e:  # Catch other unexpected errors
        logger.error(f"Failed to load or process data from {config.data_path}: {e}", exc_info=True)
        raise
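The datetime handling above targets interval-stamped rows such as '01.01.2023 00:00 - 01:00'. A standalone snippet showing what the split-then-parse step does, including the coercion of malformed rows to NaT (sample values invented):

# Standalone demo of the ' - '-split plus strptime logic in load_raw_data.
import pandas as pd

raw = pd.Series(["01.01.2023 00:00 - 01:00", "01.01.2023 01:00 - 02:00", "garbled"])
start_times = raw.str.split(" - ", expand=True)[0]   # keep only the interval start
parsed = pd.to_datetime(start_times, format="%d.%m.%Y %H:%M", errors="coerce")
print(parsed)
# 0   2023-01-01 00:00:00
# 1   2023-01-01 01:00:00
# 2                   NaT   <- coerced, later dropped with a warning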
@ -14,16 +14,16 @@ from pytorch_lightning.loggers import CSVLogger
from typing import Dict, Optional

from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.ensemble_evaluation import evaluate_fold_predictions
from forecasting_model.utils.data_processing import prepare_fold_data_and_loaders, split_data_classic

from forecasting_model.utils.helper import save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv

logger = logging.getLogger(__name__)

def run_classic_training(
def run_model_training(
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path
@ -66,7 +66,7 @@ def run_classic_training(

    # --- Data Preparation ---
    logger.info("Preparing data loaders for the classic split...")
    train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
    train_loader, val_loader, test_loader, target_scaler, data_scaler, input_size = prepare_fold_data_and_loaders(
        full_df=full_df,
        train_idx=train_idx,
        val_idx=val_idx,
@ -81,6 +81,11 @@ def run_classic_training(
    # Save artifacts specific to this run if needed (e.g., for later inference)
    torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
    torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
    if data_scaler is not None:
        torch.save(data_scaler, classic_output_dir / "classic_data_scaler.pt")
        logger.info(f"Saved data scaler to {classic_output_dir / 'classic_data_scaler.pt'}")
    else:
        logger.warning("No data scaler was returned by prepare_fold_data_and_loaders. Cannot save data scaler.")
    torch.save(input_size, classic_output_dir / "classic_input_size.pt")
    # Save config for this run
    try: config_dump = config.model_dump()
@ -93,7 +98,8 @@ def run_classic_training(
        model_config=config.model,
        train_config=config.training,
        input_size=input_size,
        target_scaler=target_scaler
        target_scaler=target_scaler,
        data_scaler=data_scaler
    )
    logger.info("Classic LSTMForecastLightningModule initialized.")

@ -132,6 +138,7 @@ def run_classic_training(
    trainer = pl.Trainer(
        accelerator=accelerator, devices=devices,
        max_epochs=config.training.epochs,
        check_val_every_n_epoch=config.training.check_val_n_epoch,
        callbacks=callbacks, logger=pl_logger,
        log_every_n_steps=max(1, len(train_loader)//10),
        enable_progress_bar=True,
@ -214,6 +221,7 @@ def run_classic_training(
        y_true_scaled=all_targets_scaled,
        y_pred_scaled=all_preds_scaled,
        target_scaler=target_scaler,
        data_scaler=data_scaler,
        eval_config=config.evaluation,
        fold_num=-1,  # Indicate classic run
        output_dir=str(classic_output_dir),
@ -273,4 +281,6 @@ def run_classic_training(
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    run_end_time = time.perf_counter()
    logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---")
    return test_metrics
    pass

    return test_metrics
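The run now persists the data scaler next to the target scaler. Since scalers are pickled Python objects rather than tensor state dicts, loading them back requires weights_only=False (recent PyTorch versions default torch.load to weights_only=True), which is exactly what the ensemble loader below does. A minimal round trip under that assumption:

# Minimal artifact round trip mirroring the torch.save calls above.
import torch
from pathlib import Path
from sklearn.preprocessing import StandardScaler

out_dir = Path("output/classic")  # hypothetical run directory
out_dir.mkdir(parents=True, exist_ok=True)

scaler = StandardScaler()
torch.save(scaler, out_dir / "classic_data_scaler.pt")
# weights_only=False is needed because the file contains an arbitrary pickled object
scaler_back = torch.load(out_dir / "classic_data_scaler.pt", weights_only=False)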
@ -1,37 +1,32 @@
"""
Ensemble evaluation for time series forecasting models.

This module provides functionality to evaluate ensemble predictions
by combining predictions from n-1 folds and testing on the remaining fold.
"""

import logging
import numpy as np
import torch
import yaml # For loading fold config
import yaml
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd # For time index handling
import pickle # Need pickle for the specific error check
import pandas as pd
import pickle

from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.utils.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig

logger = logging.getLogger(__name__)

# TODO: make this a class

def load_fold_model_and_objects(
    fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
    """
    Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.
    Load a trained model, its config, dataloader, scalers, input_size, prediction time index, and forecast horizons.

    Args:
        fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).

    Returns:
        A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
        A tuple containing (model, config, test_loader, target_scaler, data_scaler, input_size, prediction_target_time_index, forecast_horizons)
        or None if any essential artifact is missing or loading fails.
    """
    try:
@ -48,18 +43,23 @@ def load_fold_model_and_objects(

        # 2. Load Saved Objects using torch.load
        test_loader_path = fold_dir / "test_loader.pt"
        scaler_path = fold_dir / "target_scaler.pt"
        target_scaler_path = fold_dir / "target_scaler.pt"
        data_scaler_path = fold_dir / "data_scaler.pt"  # Added path for data_scaler
        input_size_path = fold_dir / "input_size.pt"
        prediction_index_path = fold_dir / "prediction_target_time_index.pt"

        if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
            logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
        # Check existence of required files
        required_paths = [test_loader_path, target_scaler_path, data_scaler_path, input_size_path]
        if not all([p.is_file() for p in required_paths]):
            missing = [p.name for p in required_paths if not p.is_file()]
            logger.error(f"Missing one or more required artifacts ({', '.join(missing)}) in {fold_dir}")
            return None

        try:
            # --- Explicitly set weights_only=False for non-model objects ---
            test_loader = torch.load(test_loader_path, weights_only=False)
            target_scaler = torch.load(scaler_path, weights_only=False)
            target_scaler = torch.load(target_scaler_path, weights_only=False)
            data_scaler = torch.load(data_scaler_path, weights_only=False)  # Load data_scaler
            input_size = torch.load(input_size_path, weights_only=False)
            # --- End Modification ---
        except pickle.UnpicklingError as e:
@ -111,15 +111,17 @@ def load_fold_model_and_objects(
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")
        model = LSTMForecastLightningModule.load_from_checkpoint(
            checkpoint_path,
            map_location=torch.device('cpu'),  # Optional: load to CPU first if memory is tight
            map_location=torch.device('cpu'),
            model_config=fold_config.model,
            train_config=fold_config.training,
            input_size=input_size,
            target_scaler=target_scaler
            target_scaler=target_scaler,
            data_scaler=data_scaler
        )
        model.eval()
        logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
        return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons
        # Return data_scaler in the tuple
        return model, fold_config, test_loader, target_scaler, data_scaler, input_size, prediction_target_time_index, forecast_horizons

    except FileNotFoundError:
        logger.error(f"Checkpoint file not found: {checkpoint_path}")
@ -278,8 +280,7 @@ def evaluate_ensemble_for_test_fold(
    if load_result is None:
        logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
        return None
    # Unpack results including the prediction time index and horizons
    _, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result
    _, test_fold_config, test_loader, target_scaler, data_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result

    # Load models from all *other* folds
    ensemble_models: List[LSTMForecastLightningModule] = []
@ -291,7 +292,8 @@ def evaluate_ensemble_for_test_fold(

    model_load_result = load_fold_model_and_objects(fold_dir)
    if model_load_result:
        model, _, _, _, _, _, fold_horizons = model_load_result  # Only need the model here
        # Unpack; only the model and horizons are needed here
        model, _, _, _, _, _, _, fold_horizons = model_load_result
        if model:
            ensemble_models.append(model)
            # Store horizons from the first successful model load
@ -353,8 +355,9 @@ def evaluate_ensemble_for_test_fold(
        y_true_scaled=targets_np,
        y_pred_scaled=preds_np,
        target_scaler=target_scaler,
        data_scaler=data_scaler,
        eval_config=test_fold_config.evaluation,
        fold_num=test_fold_num - 1,
        fold_num=test_fold_num - 1,  # Pass 0-based index if expected by eval func
        output_dir=str(method_plot_dir.parent.parent),
        plot_subdir=f"method_{method}",
        prediction_time_index=prediction_time_index_for_plot,  # Pass the index
@ -368,10 +371,8 @@ def evaluate_ensemble_for_test_fold(


def run_ensemble_evaluation(
    config: MainConfig,  # Pass main config for context if needed, though fold configs are loaded
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None  # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
# TODO: Get rid of the base dir path here, should be available through config... Check....
    config: MainConfig, output_base_dir: Path) -> Dict[int, Dict[str, Dict[str, float]]]:
    """
    Run ensemble evaluation across all folds, treating each as the test set once.
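The combination step itself falls outside these hunks; a sketch of what combining the n-1 fold models' predictions on the held-out fold typically looks like, reduced to mean and median (the module's actual set of methods may differ):

# Sketch: combining scaled predictions from n-1 fold models on the test fold.
import numpy as np

def combine_predictions(per_model_preds: list, method: str = "mean") -> np.ndarray:
    stacked = np.stack(per_model_preds, axis=0)  # (n_models, n_samples, n_horizons)
    if method == "mean":
        return stacked.mean(axis=0)
    if method == "median":
        return np.median(stacked, axis=0)
    raise ValueError(f"Unknown ensemble method: {method}")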
@ -19,12 +19,14 @@ class LSTMForecastLightningModule(pl.LightningModule):

    Encapsulates the model architecture, training, validation, and test logic.
    Uses torchmetrics for efficient metric calculation.
    Stores data_scaler and target_scaler for later use (e.g., inference).
    """
    def __init__(
        self,
        model_config: ModelConfig,
        train_config: TrainingConfig,
        input_size: int,
        data_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
        target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
    ):
        super().__init__()
@ -48,10 +50,12 @@ class LSTMForecastLightningModule(pl.LightningModule):

        self.model_config = model_config
        self.train_config = train_config
        self.target_scaler = target_scaler  # Store scaler for this fold
        self.data_scaler = data_scaler
        self.target_scaler = target_scaler

        # Use save_hyperparameters() - forecast_horizon is part of model_config which is saved
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])
        # Ignore scalers, as they have state and are not simple hyperparameters
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['data_scaler', 'target_scaler'])
        # Note: Pydantic models might not be perfectly saved/loaded by PL's hparams; check if needed.
        # If issues arise loading, might need to flatten relevant hparams manually.

@ -142,13 +146,13 @@ class LSTMForecastLightningModule(pl.LightningModule):
    def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
        """Helper to inverse transform data (preds or targets) using the stored target scaler."""
        if self.target_scaler is None:
            logger.warning("Attempted inverse transform, but `target_scaler` is None.")
            return None

        data_cpu = data.detach().cpu().numpy().astype(np.float64)
        original_shape = data_cpu.shape  # e.g., (batch_size, len(horizons))
        num_elements = data_cpu.size

        # Scaler expects 2D input (N, 1)
        data_flat = data_cpu.reshape(num_elements, 1)

        try:
@ -157,7 +161,7 @@ class LSTMForecastLightningModule(pl.LightningModule):
            inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
            # Reshape back to original multi-horizon shape
            return inversed_tensor.reshape(original_shape)
            # return inversed_tensor.flatten()  # Keep flat if needed for specific metric inputs

        except Exception as e:
            logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
            return None
@ -201,6 +205,9 @@ class LSTMForecastLightningModule(pl.LightningModule):
                logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
            else:
                logger.warning("Could not compute original scale MAE in validation due to inverse transform failure.")
        else:
            # Only log a warning if inverse transform was expected but the scaler is missing
            logger.debug("`target_scaler` is None, skipping validation MAE calculation on original scale.")


    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
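The constructor change above keeps stateful scalers as plain module attributes while excluding them from the checkpointed hyperparameters. A reduced sketch of the pattern (the real module takes far more configuration):

# Reduced sketch of the scaler-aware LightningModule constructor pattern above.
import pytorch_lightning as pl

class TinyForecastModule(pl.LightningModule):
    def __init__(self, input_size: int, data_scaler=None, target_scaler=None):
        super().__init__()
        self.data_scaler = data_scaler      # stateful sklearn objects live as attributes...
        self.target_scaler = target_scaler
        # ...but are excluded from hparams, so checkpoints stay clean and loadable
        self.save_hyperparameters('input_size', ignore=['data_scaler', 'target_scaler'])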
@ -2,174 +2,20 @@ import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
||||
from typing import Tuple, Generator, List, Optional, Union, Dict, Literal, Type
|
||||
from typing import Tuple, Optional, Union, Type
|
||||
import math # Add math import
|
||||
|
||||
# Use relative import for utils within the package
|
||||
from .utils.forecast_config_model import DataConfig, FeatureConfig, TrainingConfig, EvaluationConfig, CrossValidationConfig
|
||||
from forecasting_model.utils.dataset_splitter import TimeSeriesDataset
|
||||
from forecasting_model.utils.forecast_config_model import FeatureConfig, TrainingConfig, EvaluationConfig
|
||||
|
||||
# Optional: Import wavelet library if needed later
|
||||
# import pywt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Data Loading ---
def load_raw_data(config: DataConfig) -> pd.DataFrame:
    """
    Load raw time series data from a CSV file, handling specific formats,
    performing initial cleaning, frequency checks, and NaN filling based on config.

    Args:
        config: DataConfig object containing file path, raw/standard column names,
                frequency settings, and NaN handling flags.

    Returns:
        DataFrame with a standardized datetime index (named config.datetime_col)
        and a standardized, cleaned target column (named config.target_col).

    Raises:
        FileNotFoundError: If the data path does not exist.
        ValueError: If specified raw columns are not found, datetime parsing fails,
                    or frequency checks indicate critical issues.
        Exception: For other pandas read_csv or processing errors.
    """
    logger.info(f"Loading raw data from: {config.data_path}")
    try:
        # --- Initial Load ---
        df = pd.read_csv(config.data_path, header=0)
        logger.debug(f"Loaded raw data shape: {df.shape}")

        # --- Validate Raw Columns ---
        if config.raw_datetime_col not in df.columns:
            raise ValueError(f"Raw datetime column '{config.raw_datetime_col}' not found in {config.data_path}")
        if config.raw_target_col not in df.columns:
            raise ValueError(f"Raw target column '{config.raw_target_col}' not found in {config.data_path}")

        # --- Time Parsing (Specific Format Handling) ---
        logger.info(f"Parsing raw datetime column: '{config.raw_datetime_col}'")
        try:
            # Extract the start time part 'dd.mm.yyyy hh:mm'
            # Handle potential errors during split if the format deviates
            start_times = df[config.raw_datetime_col].astype(str).str.split(' - ', expand=True)[0]
            # Define the specific format
            datetime_format = config.raw_datetime_format or '%d.%m.%Y %H:%M'
            # Parse to datetime, coercing errors to NaT
            parsed_timestamps = pd.to_datetime(start_times, format=datetime_format, errors='coerce')
        except Exception as e:
            logger.error(f"Failed to split or parse raw datetime column '{config.raw_datetime_col}' using expected format: {e}", exc_info=True)
            raise ValueError("Datetime parsing failed. Check raw_datetime_col format and data.")

        # Check for parsing errors (NaT values)
        num_parsing_errors = parsed_timestamps.isnull().sum()
        if num_parsing_errors > 0:
            original_len = len(df)
            df = df.loc[parsed_timestamps.notnull()].copy()  # Keep only rows with valid timestamps
            parsed_timestamps = parsed_timestamps.dropna()
            logger.warning(f"Dropped {num_parsing_errors} rows ({num_parsing_errors/original_len:.1%}) due to timestamp parsing errors "
                           f"(expected format: '{datetime_format}' on start time).")
            if df.empty:
                raise ValueError("No valid timestamps found after parsing. Check data format.")

        # Assign the parsed timestamp and set it as the index with the standardized name
        df[config.datetime_col] = parsed_timestamps
        df = df.set_index(config.datetime_col)
        logger.debug(f"Set '{config.datetime_col}' as index.")

        # --- Target Column Processing ---
        logger.info(f"Processing target column: '{config.raw_target_col}' -> '{config.target_col}'")
        # Convert the raw target to numeric, coercing errors
        df[config.target_col] = pd.to_numeric(df[config.raw_target_col], errors='coerce')

        # Handle NaNs caused by coercion
        num_coercion_errors = df[config.target_col].isnull().sum()
        if num_coercion_errors > 0:
            logger.warning(f"Found {num_coercion_errors} non-numeric values in raw target column '{config.raw_target_col}'. Coerced to NaN.")
            # Keep rows with NaN for now; handle based on the config flag below

        # --- Column Selection ---
        # Keep only the standardized target column for the forecasting pipeline
        # Discard raw columns and any others loaded initially
        df = df[[config.target_col]]
        logger.debug(f"Selected target column '{config.target_col}'. Shape: {df.shape}")

        # --- Initial Target NaN Filling (Optional) ---
        if config.fill_initial_target_nans:
            missing_prices = df[config.target_col].isnull().sum()
            if missing_prices > 0:
                logger.info(f"Found {missing_prices} missing values in target column '{config.target_col}'. Applying ffill then bfill.")
                df[config.target_col] = df[config.target_col].ffill()
                df[config.target_col] = df[config.target_col].bfill()  # Fill remaining NaNs at the start

                final_missing = df[config.target_col].isnull().sum()
                if final_missing > 0:
                    logger.error(f"{final_missing} missing values REMAIN in target column after ffill/bfill. Cannot proceed.")
                    raise ValueError("Target column contains unfillable NaN values.")
            else:
                logger.debug("No missing values found in target column.")
        else:
            logger.info("Skipping initial NaN filling for target column as per config.")
            # Warn if NaNs exist and aren't being filled here
            if df[config.target_col].isnull().any():
                logger.warning(f"NaNs exist in target column '{config.target_col}' and initial filling is disabled.")

        # --- Frequency Check & Setting ---
        logger.info("Checking time index frequency...")
        df = df.sort_index()  # Ensure the index is sorted before frequency checks

        # Handle duplicate timestamps before frequency inference
        duplicates = df.index.duplicated().sum()
        if duplicates > 0:
            logger.warning(f"Found {duplicates} duplicate timestamps. Keeping the first occurrence.")
            df = df[~df.index.duplicated(keep='first')]

        if config.expected_frequency:
            inferred_freq = pd.infer_freq(df.index)
            logger.debug(f"Inferred frequency: {inferred_freq}")

            if inferred_freq == config.expected_frequency:
                logger.info(f"Inferred frequency matches expected ('{config.expected_frequency}'). Setting index frequency.")
                df = df.asfreq(config.expected_frequency)
                # Check for NaNs introduced by asfreq (filling gaps)
                missing_after_asfreq = df[config.target_col].isnull().sum()
                if missing_after_asfreq > 0:
                    logger.warning(f"{missing_after_asfreq} NaNs appeared after setting frequency to '{config.expected_frequency}'. Applying ffill/bfill.")
                    # Be explicit: only fill if initial filling was also enabled; otherwise just warn
                    if config.fill_initial_target_nans:
                        df[config.target_col] = df[config.target_col].ffill().bfill()
                        if df[config.target_col].isnull().any():
                            logger.error("NaNs still present after attempting to fill gaps from asfreq. Check data continuity.")
                            raise ValueError("Unfillable NaNs after setting frequency.")
                    else:
                        logger.warning("Initial NaN filling was disabled, leaving NaNs introduced by asfreq.")

            elif inferred_freq:
                logger.warning(f"Inferred frequency ('{inferred_freq}') does NOT match expected ('{config.expected_frequency}'). Index frequency will not be explicitly set. This might affect time-based features or models assuming regular intervals.")
                # Consider raising an error depending on the strictness needed
                # raise ValueError("Inferred frequency does not match expected frequency.")
            else:
                logger.error(f"Could not infer frequency, but expected frequency was set to '{config.expected_frequency}'. Check data for gaps or irregularities. Index frequency will not be explicitly set.")
                # This is often a critical issue for time series models
                raise ValueError("Could not infer frequency. Ensure data has regular intervals matching expected_frequency.")
        else:
            logger.info("No expected frequency specified in config. Skipping frequency check and setting.")

        logger.info(f"Data loading and initial preparation complete. Final shape: {df.shape}")
        return df

    except FileNotFoundError:
        logger.error(f"Data file not found at: {config.data_path}")
        raise
    except ValueError as e:  # Catch ValueErrors raised internally or by pandas
        logger.error(f"Data processing error: {e}", exc_info=True)
        raise
    except Exception as e:  # Catch other unexpected errors
        logger.error(f"Failed to load or process data from {config.data_path}: {e}", exc_info=True)
        raise
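# --- Editor's note: illustrative sketch, not part of this commit. It demonstrates
# the datetime handling above on interval strings of the form 'start - end', which
# is the raw format this loader expects; the sample values are invented.
import pandas as pd

raw = pd.Series(["01.01.2024 00:00 - 01.01.2024 01:00",
                 "01.01.2024 01:00 - 01.01.2024 02:00"])
start_times = raw.str.split(' - ', expand=True)[0]
parsed = pd.to_datetime(start_times, format='%d.%m.%Y %H:%M', errors='coerce')
print(parsed)  # 2024-01-01 00:00:00, 2024-01-01 01:00:00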
# --- Feature Engineering ---
def engineer_features(df: pd.DataFrame, target_col: str, feature_config: FeatureConfig) -> pd.DataFrame:
@@ -242,7 +88,7 @@ def engineer_features(df: pd.DataFrame, target_col: str, feature_config: Feature
        seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
        features_df['sin_day'] = np.sin(2 * np.pi * seconds_past_midnight / seconds_in_day)

-    if feature_config.cosin_curve:  # Assuming this means cos for day
+    if feature_config.cosine_curve:  # Assuming this means cos for day
        logger.debug("Creating cosinusoidal daily time feature.")
        seconds_in_day = 24 * 60 * 60
        seconds_past_midnight = features_df.index.hour * 3600 + features_df.index.minute * 60 + features_df.index.second
@@ -324,233 +170,6 @@ def engineer_features(df: pd.DataFrame, target_col: str, feature_config: Feature
    return features_df


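# --- Editor's note: illustrative sketch, not part of this commit. It shows the
# cyclical time encoding used above: seconds past midnight mapped onto one sine
# and one cosine period per day, so 23:59 and 00:00 end up numerically close.
import numpy as np
import pandas as pd

idx = pd.date_range("2024-01-01", periods=48, freq="h")
seconds_in_day = 24 * 60 * 60
seconds_past_midnight = idx.hour * 3600 + idx.minute * 60 + idx.second
sin_day = np.sin(2 * np.pi * seconds_past_midnight / seconds_in_day)
cos_day = np.cos(2 * np.pi * seconds_past_midnight / seconds_in_day)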
# --- Cross Validation ---
class TimeSeriesCrossValidationSplitter:
    """
    Generates indices for time series cross-validation using a rolling (sliding) window.

    The training window has a fixed size. For each split, the entire window
    (train, validation, and test sets) slides forward by a specified step size
    (typically the size of the test set). Validation and test set sizes are
    calculated as fractions of the fixed training window size.
    """
    def __init__(self, config: CrossValidationConfig, n_samples: int):
        """
        Args:
            config: CrossValidationConfig with split parameters.
            n_samples: Total number of samples in the dataset.
        """
        self.n_splits = config.n_splits
        self.val_frac = config.val_size_fraction
        self.test_frac = config.test_size_fraction
        self.initial_train_size = config.initial_train_size  # Used as the FIXED train size for the rolling window
        self.n_samples = n_samples

        if not (0 < self.val_frac < 1):
            raise ValueError(f"val_size_fraction must be between 0 and 1, got {self.val_frac}")
        if not (0 < self.test_frac < 1):
            raise ValueError(f"test_size_fraction must be between 0 and 1, got {self.test_frac}")
        if self.n_splits <= 0:
            raise ValueError(f"n_splits must be positive, got {self.n_splits}")

        logger.info(f"Initializing TimeSeriesCrossValidationSplitter (Rolling Window): n_splits={self.n_splits}, "
                    f"val_frac={self.val_frac}, test_frac={self.test_frac}, initial_train_size (fixed)={self.initial_train_size}")  # Clarified log

    def _calculate_initial_train_size(self) -> int:
        """Determines the fixed training window size based on config or estimation."""
        # Check if an integer is provided
        if isinstance(self.initial_train_size, int) and self.initial_train_size > 0:
            if self.initial_train_size >= self.n_samples:
                raise ValueError(f"initial_train_size ({self.initial_train_size}) must be less than total samples ({self.n_samples})")
            logger.info(f"Using specified fixed training window size: {self.initial_train_size}")
            return self.initial_train_size

        # Check if a float/fraction is provided
        elif isinstance(self.initial_train_size, float) and 0 < self.initial_train_size < 1:
            calculated_size = int(self.n_samples * self.initial_train_size)
            if calculated_size <= 0:
                raise ValueError("initial_train_size fraction results in non-positive size.")
            logger.info(f"Using fixed training window size calculated from fraction: {calculated_size}")
            return calculated_size

        # Estimate if None
        elif self.initial_train_size is None:
            logger.info("Estimating fixed train size based on n_splits, val_frac, test_frac.")
            # Estimate based on the total space needed for all splits:
            # n_samples >= fixed_train_n + val_size + test_size + (n_splits - 1) * step_size
            # n_samples >= fixed_train_n + int(fixed_train_n*val_frac) + n_splits * int(fixed_train_n*test_frac)
            # n_samples >= fixed_train_n * (1 + val_frac + n_splits * test_frac)
            # fixed_train_n <= n_samples / (1 + val_frac + n_splits * test_frac)

            denominator = 1.0 + self.val_frac + self.n_splits * self.test_frac
            if denominator <= 1.0:  # Avoid division by zero or non-positive, and ensure train frac < 1
                raise ValueError(f"Cannot estimate initial_train_size. Combination of val_frac ({self.val_frac}), "
                                 f"test_frac ({self.test_frac}), and n_splits ({self.n_splits}) is invalid (denominator {denominator:.2f} <= 1.0).")

            estimated_size = int(self.n_samples / denominator)

            # Add a sanity check: ensure the estimated size is reasonably large
            min_required_for_features = 1  # Placeholder - ideally get from FeatureConfig if possible, but complex here
            if estimated_size < min_required_for_features:
                raise ValueError(f"Estimated fixed train size ({estimated_size}) is too small. "
                                 f"Check CV config (n_splits={self.n_splits}, val_frac={self.val_frac}, test_frac={self.test_frac}) "
                                 f"relative to total samples ({self.n_samples}). Consider specifying initial_train_size manually.")

            logger.info(f"Estimated fixed training window size: {estimated_size}")
            return estimated_size
        else:
            raise ValueError(f"Invalid initial_train_size type or value: {self.initial_train_size}")

    def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
        """
        Generate train/validation/test indices for each fold using a rolling window.
        Pre-calculates the number of possible splits based on data size and window parameters.

        Yields:
            Tuple of (train_indices, val_indices, test_indices) for each fold.

        Raises:
            ValueError: If parameters lead to invalid split sizes or overlaps,
                        or if the data is too small for the configuration.
        """
        indices = np.arange(self.n_samples)
        fixed_train_n = self._calculate_initial_train_size()  # This is now the fixed size

        # Calculate val/test sizes based on the *fixed* training size. Minimum size of 1.
        val_size = max(1, int(fixed_train_n * self.val_frac))
        test_size = max(1, int(fixed_train_n * self.test_frac))

        # Calculate the total size of one complete train+val+test window
        fold_window_size = fixed_train_n + val_size + test_size

        # Check if even the first window fits
        if fold_window_size > self.n_samples:
            raise ValueError(f"Configuration Error: The total window size (Train {fixed_train_n} + Val {val_size} + Test {test_size} = {fold_window_size}) "
                             f"exceeds total samples ({self.n_samples}). Decrease initial_train_size, fractions, or increase data.")

        # Determine the step size (how much the window slides)
        # Default: slide by the test set size for contiguous, non-overlapping test periods
        step_size = test_size
        if step_size <= 0:
            raise ValueError(f"Step size (derived from test_size {test_size}) must be positive.")

        # --- Calculate the number of splits actually possible ---
        # Last possible start index for the train set
        last_possible_train_start_idx = self.n_samples - fold_window_size
        # Calculate how many steps fit within this range (integer division)
        # If the last possible start is 5 and the step is 2, starts are possible at 0, 2, 4 => (5 // 2) + 1 = 3
        num_possible_steps = max(0, last_possible_train_start_idx // step_size) + 1  # +1 because we start at index 0

        # Use the minimum of requested splits and possible splits
        actual_n_splits = min(self.n_splits, num_possible_steps)

        if actual_n_splits < self.n_splits:
            logger.warning(f"Data size ({self.n_samples} samples) only allows for {actual_n_splits} splits "
                           f"with fixed train size {fixed_train_n}, val size {val_size}, test size {test_size} (total window {fold_window_size}) and step size {step_size} "
                           f"(requested {self.n_splits}).")
        elif actual_n_splits == 0:
            # This case should be caught by the fold_window_size > self.n_samples check, but belt-and-suspenders
            logger.error("Data too small for even one split with the rolling window configuration.")
            return  # Return a generator that yields nothing

        # --- Generate the splits ---
        for i in range(actual_n_splits):
            logger.debug(f"Generating indices for fold {i+1}/{actual_n_splits} (Rolling Window)")  # Log using actual_n_splits

            # Calculate window boundaries for this fold
            train_start_idx = i * step_size
            train_end_idx = train_start_idx + fixed_train_n
            val_start_idx = train_end_idx
            val_end_idx = val_start_idx + val_size
            test_start_idx = val_end_idx
            test_end_idx = test_start_idx + test_size  # = train_start_idx + fold_window_size

            # Determine indices for this fold using slicing
            train_indices = indices[train_start_idx:train_end_idx]
            val_indices = indices[val_start_idx:val_end_idx]
            test_indices = indices[test_start_idx:test_end_idx]

            # --- Basic validation checks (optional; guaranteed by the calculations) ---
            # No overlap and correct ordering are guaranteed by the slicing above

            logger.info(f"Fold {i+1}: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
                        f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
                        f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")

            yield train_indices, val_indices, test_indices
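# --- Editor's note: illustrative usage sketch, not part of this commit. The
# CrossValidationConfig values and its import path are assumed for demonstration.
from forecasting_model.utils.forecast_config_model import CrossValidationConfig

cv_config = CrossValidationConfig(n_splits=3, val_size_fraction=0.1,
                                  test_size_fraction=0.2, initial_train_size=None)
splitter = TimeSeriesCrossValidationSplitter(cv_config, n_samples=1000)
for fold, (train_idx, val_idx, test_idx) in enumerate(splitter.split(), start=1):
    print(f"Fold {fold}: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")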


# --- Dataset Class ---
class TimeSeriesDataset(Dataset):
    """
    PyTorch Dataset for time series forecasting.

    Takes a NumPy array (features + target), sequence length, and a list of
    specific forecast horizons. Returns (input_sequence, target_vector) tuples,
    where target_vector contains the target values at the specified future steps.
    """
    def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: List[int], target_col_index: int = 0):
        """
        Args:
            data_array: Numpy array of shape (n_samples, n_features).
                        Assumes the target variable is one of the columns.
            sequence_length: Length of the input sequence (lookback window).
            forecast_horizon: List of specific steps ahead to predict (e.g., [1, 6, 12]).
            target_col_index: Index of the target column in data_array. Defaults to 0.
        """
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive.")
        if not forecast_horizon or not isinstance(forecast_horizon, list) or any(h <= 0 for h in forecast_horizon):
            raise ValueError("forecast_horizon must be a non-empty list of positive integers.")
        if data_array.ndim != 2:
            raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")

        self.max_horizon = max(forecast_horizon)  # Find the furthest point needed

        min_len_required = sequence_length + self.max_horizon
        if min_len_required > data_array.shape[0]:
            raise ValueError(f"sequence_length ({sequence_length}) + max_horizon ({self.max_horizon}) = {min_len_required} "
                             f"exceeds total samples provided ({data_array.shape[0]})")
        if not (0 <= target_col_index < data_array.shape[1]):
            raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")

        self.data = torch.tensor(data_array, dtype=torch.float32)
        self.sequence_length = sequence_length
        self.forecast_horizon_list = sorted(forecast_horizon)
        self.target_col_index = target_col_index
        self.n_samples = data_array.shape[0]
        self.n_features = data_array.shape[1]

        logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
                     f"seq_len={self.sequence_length}, forecast_horizons={self.forecast_horizon_list}, "
                     f"max_horizon={self.max_horizon}, target_idx={self.target_col_index}")

    def __len__(self) -> int:
        """Returns the total number of sequences that can be generated."""
        return self.n_samples - self.sequence_length - self.max_horizon + 1

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a single (input_sequence, target_vector) pair.
        The target vector contains values for the specified forecast horizons.
        """
        if not (0 <= idx < len(self)):
            raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")

        input_start = idx
        input_end = idx + self.sequence_length
        input_sequence = self.data[input_start:input_end, :]  # Shape: (seq_len, n_features)

        # Calculate indices for each horizon relative to the end of the input sequence
        # Horizon h corresponds to index: input_end + h - 1
        target_indices = [input_end + h - 1 for h in self.forecast_horizon_list]
        target_vector = self.data[target_indices, self.target_col_index]  # Shape: (len(forecast_horizon_list),)

        return input_sequence, target_vector
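# --- Editor's note: illustrative worked example, not part of this commit. With
# sequence_length=4 and forecast_horizon=[1, 3], sample idx=0 uses rows 0-3 as
# input; its targets sit at rows 0 + 4 + 1 - 1 = 4 and 0 + 4 + 3 - 1 = 6.
import numpy as np

data = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 samples, 2 features
ds = TimeSeriesDataset(data, sequence_length=4, forecast_horizon=[1, 3], target_col_index=0)
assert len(ds) == 10 - 4 - 3 + 1  # == 4
x, y = ds[0]
assert x.shape == (4, 2) and y.shape == (2,)
assert y.tolist() == [data[4, 0], data[6, 0]]  # rows 4 and 6, target column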

# --- Data Preparation ---
def prepare_fold_data_and_loaders(
    full_df: pd.DataFrame,  # Should contain only the target initially
@@ -561,7 +180,7 @@ def prepare_fold_data_and_loaders(
    feature_config: FeatureConfig,
    train_config: TrainingConfig,
    eval_config: EvaluationConfig
-) -> Tuple[DataLoader, DataLoader, DataLoader, Union[StandardScaler, MinMaxScaler, None], int]:
+) -> Tuple[DataLoader, DataLoader, DataLoader, Optional[Union[StandardScaler, MinMaxScaler]], Optional[Union[StandardScaler, MinMaxScaler]], int]:
    """
    Prepares data loaders for a single cross-validation fold.

@@ -588,14 +207,16 @@ def prepare_fold_data_and_loaders(
        feature_config: Configuration for feature engineering.
        train_config: Configuration for training (used for batch size, device hints).
        eval_config: Configuration for evaluation (used for batch size).

    Returns:
        Tuple containing:
        - train_loader: DataLoader for the training set.
        - val_loader: DataLoader for the validation set.
        - test_loader: DataLoader for the test set.
-        - target_scaler: The scaler fitted on the target variable (for inverse transform). Can be None.
-        - scaler: The scaler fitted on all features of the training data (for transforming new data). Can be None.
+        - target_scaler: The scaler fitted on the target variable of the training data (for inverse transform). Can be None.
+        - data_scaler: The scaler fitted on all features of the training data (for transforming new data). Can be None.
        - input_size: The number of features in the input sequences (X).

    Raises:
@@ -621,12 +242,12 @@ def prepare_fold_data_and_loaders(
    # Max history *before* the start of the training set
    max_history_needed_before_train = max(max_lookback, feature_config.sequence_length)

-    slice_start_idx = max(0, train_idx[0] - max_history_needed_before_train)
+    slice_start_idx = max(0, int(train_idx[0] - max_history_needed_before_train))
    # The end index needs to cover the test set PLUS the maximum horizon needed for the last test target
    slice_end_idx = test_idx[-1] + max_horizon_needed  # Go up to the last needed target

    # Ensure the end index is within bounds
-    slice_end_idx = min(slice_end_idx + 1, len(full_df))  # +1 because iloc is exclusive
+    slice_end_idx = min(int(slice_end_idx + 1), len(full_df))  # +1 because iloc is exclusive

    if slice_start_idx >= slice_end_idx:
        raise ValueError(f"Calculated slice start ({slice_start_idx}) >= slice end ({slice_end_idx}). Check indices and horizon.")
@@ -700,7 +321,7 @@ def prepare_fold_data_and_loaders(
    except ValueError:
        raise ValueError(f"Target column '{target_col}' not found in the final feature columns: {feature_cols}")

-    scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
+    data_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None
    ScalerClass: Optional[Type[Union[StandardScaler, MinMaxScaler]]] = None

@@ -718,19 +339,29 @@ def prepare_fold_data_and_loaders(
    test_data = test_df[feature_cols].values

    if ScalerClass is not None:
-        scaler = ScalerClass()
-        target_scaler = ScalerClass()
+        data_scaler = ScalerClass()
+        target_scaler = ScalerClass()  # Initialize the target scaler regardless of whether target_col is present
        logger.info(f"Applying {feature_config.scaling_method} scaling. Fitting on training data for the fold.")
-        scaler.fit(train_data)
-        target_scaler.fit(train_data[:, target_col_index_in_features].reshape(-1, 1))
-        train_data_scaled = scaler.transform(train_data)
-        val_data_scaled = scaler.transform(val_data)
-        test_data_scaled = scaler.transform(test_data)

+        # Fit the main scaler on all training features
+        data_scaler.fit(train_data)

+        # Fit the target scaler only on the target column
+        target_data_train = train_data[:, target_col_index_in_features].reshape(-1, 1)
+        target_scaler.fit(target_data_train)

+        # Transform all datasets using the main scaler
+        train_data_scaled = data_scaler.transform(train_data)
+        val_data_scaled = data_scaler.transform(val_data)
+        test_data_scaled = data_scaler.transform(test_data)
        logger.debug("Scaling complete for the fold.")
    else:
        train_data_scaled = train_data
        val_data_scaled = val_data
        test_data_scaled = test_data
+        data_scaler = None  # Explicitly set to None if no scaling
+        target_scaler = None  # Explicitly set to None if no scaling

    input_size = train_data_scaled.shape[1]

@@ -789,10 +420,9 @@ def prepare_fold_data_and_loaders(

    logger.info("Data loaders prepared successfully for the fold.")

-    return train_loader, val_loader, test_loader, target_scaler, input_size
+    return train_loader, val_loader, test_loader, target_scaler, data_scaler, input_size
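# --- Editor's note: hypothetical call site, not part of this commit. It shows the
# updated six-element return contract; the intermediate arguments elided in the
# hunks above are assumed here to be the fold's index arrays and the target column
# name, and the config objects are assumed to exist in the caller.
train_loader, val_loader, test_loader, target_scaler, data_scaler, input_size = (
    prepare_fold_data_and_loaders(
        full_df, train_idx, val_idx, test_idx,   # assumed argument order
        target_col, feature_config, train_config, eval_config,
    )
)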

# --- Classic Train/Val/Test Split ---

def split_data_classic(
    n_samples: int,
    val_frac: float,
forecasting_model/utils/dataset_splitter.py (new file, 234 lines)
@@ -0,0 +1,234 @@
from typing import Generator, Tuple, List

import numpy as np
import torch
from torch.utils.data import Dataset

from forecasting_model.utils.helper import logger
from forecasting_model.utils import CrossValidationConfig


class TimeSeriesCrossValidationSplitter:
    """
    Generates indices for time series cross-validation using a rolling (sliding) window.

    The training window has a fixed size. For each split, the entire window
    (train, validation, and test sets) slides forward by a specified step size
    (typically the size of the test set). Validation and test set sizes are
    calculated as fractions of the fixed training window size.
    """
    def __init__(self, config: CrossValidationConfig, n_samples: int):
        """
        Args:
            config: CrossValidationConfig with split parameters.
            n_samples: Total number of samples in the dataset.
        """
        self.n_splits = config.n_splits
        self.val_frac = config.val_size_fraction
        self.test_frac = config.test_size_fraction
        self.initial_train_size = config.initial_train_size  # Used as the FIXED train size for the rolling window
        self.n_samples = n_samples

        if not (0 < self.val_frac < 1):
            raise ValueError(f"val_size_fraction must be between 0 and 1, got {self.val_frac}")
        if not (0 < self.test_frac < 1):
            raise ValueError(f"test_size_fraction must be between 0 and 1, got {self.test_frac}")
        if self.n_splits <= 0:
            raise ValueError(f"n_splits must be positive, got {self.n_splits}")

        logger.info(f"Initializing TimeSeriesCrossValidationSplitter (Rolling Window): n_splits={self.n_splits}, "
                    f"val_frac={self.val_frac}, test_frac={self.test_frac}, initial_train_size (fixed)={self.initial_train_size}")  # Clarified log

    def _calculate_initial_train_size(self) -> int:
        """Determines the fixed training window size based on config or estimation."""
        # Check if an integer is provided
        if isinstance(self.initial_train_size, int) and self.initial_train_size > 0:
            if self.initial_train_size >= self.n_samples:
                raise ValueError(f"initial_train_size ({self.initial_train_size}) must be less than total samples ({self.n_samples})")
            logger.info(f"Using specified fixed training window size: {self.initial_train_size}")
            return self.initial_train_size

        # Check if a float/fraction is provided
        elif isinstance(self.initial_train_size, float) and 0 < self.initial_train_size < 1:
            calculated_size = int(self.n_samples * self.initial_train_size)
            if calculated_size <= 0:
                raise ValueError("initial_train_size fraction results in non-positive size.")
            logger.info(f"Using fixed training window size calculated from fraction: {calculated_size}")
            return calculated_size

        # Estimate if None
        elif self.initial_train_size is None:
            logger.info("Estimating fixed train size based on n_splits, val_frac, test_frac.")
            # Estimate based on the total space needed for all splits:
            # n_samples >= fixed_train_n + val_size + test_size + (n_splits - 1) * step_size
            # n_samples >= fixed_train_n + int(fixed_train_n*val_frac) + n_splits * int(fixed_train_n*test_frac)
            # n_samples >= fixed_train_n * (1 + val_frac + n_splits * test_frac)
            # fixed_train_n <= n_samples / (1 + val_frac + n_splits * test_frac)

            denominator = 1.0 + self.val_frac + self.n_splits * self.test_frac
            if denominator <= 1.0:  # Avoid division by zero or non-positive, and ensure train frac < 1
                raise ValueError(f"Cannot estimate initial_train_size. Combination of val_frac ({self.val_frac}), "
                                 f"test_frac ({self.test_frac}), and n_splits ({self.n_splits}) is invalid (denominator {denominator:.2f} <= 1.0).")

            estimated_size = int(self.n_samples / denominator)

            # Add a sanity check: ensure the estimated size is reasonably large
            min_required_for_features = 1  # Placeholder - ideally get from FeatureConfig if possible, but complex here
            if estimated_size < min_required_for_features:
                raise ValueError(f"Estimated fixed train size ({estimated_size}) is too small. "
                                 f"Check CV config (n_splits={self.n_splits}, val_frac={self.val_frac}, test_frac={self.test_frac}) "
                                 f"relative to total samples ({self.n_samples}). Consider specifying initial_train_size manually.")

            logger.info(f"Estimated fixed training window size: {estimated_size}")
            return estimated_size
        else:
            raise ValueError(f"Invalid initial_train_size type or value: {self.initial_train_size}")

    def split(self) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
        """
        Generate train/validation/test indices for each fold using a rolling window.
        Pre-calculates the number of possible splits based on data size and window parameters.

        Yields:
            Tuple of (train_indices, val_indices, test_indices) for each fold.

        Raises:
            ValueError: If parameters lead to invalid split sizes or overlaps,
                        or if the data is too small for the configuration.
        """
        indices = np.arange(self.n_samples)
        fixed_train_n = self._calculate_initial_train_size()  # This is now the fixed size

        # Calculate val/test sizes based on the *fixed* training size. Minimum size of 1.
        val_size = max(1, int(fixed_train_n * self.val_frac))
        test_size = max(1, int(fixed_train_n * self.test_frac))

        # Calculate the total size of one complete train+val+test window
        fold_window_size = fixed_train_n + val_size + test_size

        # Check if even the first window fits
        if fold_window_size > self.n_samples:
            raise ValueError(f"Configuration Error: The total window size (Train {fixed_train_n} + Val {val_size} + Test {test_size} = {fold_window_size}) "
                             f"exceeds total samples ({self.n_samples}). Decrease initial_train_size, fractions, or increase data.")

        # Determine the step size (how much the window slides)
        # Default: slide by the test set size for contiguous, non-overlapping test periods
        step_size = test_size
        if step_size <= 0:
            raise ValueError(f"Step size (derived from test_size {test_size}) must be positive.")

        # --- Calculate the number of splits actually possible ---
        # Last possible start index for the train set
        last_possible_train_start_idx = self.n_samples - fold_window_size
        # Calculate how many steps fit within this range (integer division)
        # If the last possible start is 5 and the step is 2, starts are possible at 0, 2, 4 => (5 // 2) + 1 = 3
        num_possible_steps = max(0, last_possible_train_start_idx // step_size) + 1  # +1 because we start at index 0

        # Use the minimum of requested splits and possible splits
        actual_n_splits = min(self.n_splits, num_possible_steps)

        if actual_n_splits < self.n_splits:
            logger.warning(f"Data size ({self.n_samples} samples) only allows for {actual_n_splits} splits "
                           f"with fixed train size {fixed_train_n}, val size {val_size}, test size {test_size} (total window {fold_window_size}) and step size {step_size} "
                           f"(requested {self.n_splits}).")
        elif actual_n_splits == 0:
            # This case should be caught by the fold_window_size > self.n_samples check, but belt-and-suspenders
            logger.error("Data too small for even one split with the rolling window configuration.")
            return  # Return a generator that yields nothing

        # --- Generate the splits ---
        for i in range(actual_n_splits):
            logger.debug(f"Generating indices for fold {i+1}/{actual_n_splits} (Rolling Window)")  # Log using actual_n_splits

            # Calculate window boundaries for this fold
            train_start_idx = i * step_size
            train_end_idx = train_start_idx + fixed_train_n
            val_start_idx = train_end_idx
            val_end_idx = val_start_idx + val_size
            test_start_idx = val_end_idx
            test_end_idx = test_start_idx + test_size  # = train_start_idx + fold_window_size

            # Determine indices for this fold using slicing
            train_indices = indices[train_start_idx:train_end_idx]
            val_indices = indices[val_start_idx:val_end_idx]
            test_indices = indices[test_start_idx:test_end_idx]

            # --- Basic validation checks (optional; guaranteed by the calculations) ---
            # No overlap and correct ordering are guaranteed by the slicing above

            logger.info(f"Fold {i+1}: Train indices {train_indices[0]}-{train_indices[-1]} (size {len(train_indices)}), "
                        f"Val indices {val_indices[0]}-{val_indices[-1]} (size {len(val_indices)}), "
                        f"Test indices {test_indices[0]}-{test_indices[-1]} (size {len(test_indices)})")

            yield train_indices, val_indices, test_indices


class TimeSeriesDataset(Dataset):
    """
    PyTorch Dataset for time series forecasting.

    Takes a NumPy array (features + target), sequence length, and a list of
    specific forecast horizons. Returns (input_sequence, target_vector) tuples,
    where target_vector contains the target values at the specified future steps.
    """
    def __init__(self, data_array: np.ndarray, sequence_length: int, forecast_horizon: List[int], target_col_index: int = 0):
        """
        Args:
            data_array: Numpy array of shape (n_samples, n_features).
                        Assumes the target variable is one of the columns.
            sequence_length: Length of the input sequence (lookback window).
            forecast_horizon: List of specific steps ahead to predict (e.g., [1, 6, 12]).
            target_col_index: Index of the target column in data_array. Defaults to 0.
        """
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive.")
        if not forecast_horizon or not isinstance(forecast_horizon, list) or any(h <= 0 for h in forecast_horizon):
            raise ValueError("forecast_horizon must be a non-empty list of positive integers.")
        if data_array.ndim != 2:
            raise ValueError(f"data_array must be 2D, but got shape {data_array.shape}")

        self.max_horizon = max(forecast_horizon)  # Find the furthest point needed

        min_len_required = sequence_length + self.max_horizon
        if min_len_required > data_array.shape[0]:
            raise ValueError(f"sequence_length ({sequence_length}) + max_horizon ({self.max_horizon}) = {min_len_required} "
                             f"exceeds total samples provided ({data_array.shape[0]})")
        if not (0 <= target_col_index < data_array.shape[1]):
            raise ValueError(f"target_col_index ({target_col_index}) out of bounds for data with {data_array.shape[1]} columns.")

        self.data = torch.tensor(data_array, dtype=torch.float32)
        self.sequence_length = sequence_length
        self.forecast_horizon_list = sorted(forecast_horizon)
        self.target_col_index = target_col_index
        self.n_samples = data_array.shape[0]
        self.n_features = data_array.shape[1]

        logger.debug(f"TimeSeriesDataset created: data shape={self.data.shape}, "
                     f"seq_len={self.sequence_length}, forecast_horizons={self.forecast_horizon_list}, "
                     f"max_horizon={self.max_horizon}, target_idx={self.target_col_index}")

    def __len__(self) -> int:
        """Returns the total number of sequences that can be generated."""
        return self.n_samples - self.sequence_length - self.max_horizon + 1

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a single (input_sequence, target_vector) pair.
        The target vector contains values for the specified forecast horizons.
        """
        if not (0 <= idx < len(self)):
            raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self)}")

        input_start = idx
        input_end = idx + self.sequence_length
        input_sequence = self.data[input_start:input_end, :]  # Shape: (seq_len, n_features)

        # Calculate indices for each horizon relative to the end of the input sequence
        # Horizon h corresponds to index: input_end + h - 1
        target_indices = [input_end + h - 1 for h in self.forecast_horizon_list]
        target_vector = self.data[target_indices, self.target_col_index]  # Shape: (len(forecast_horizon_list),)

        return input_sequence, target_vector
forecasting_model/utils/evaluation.py (new file, 218 lines)
@@ -0,0 +1,218 @@
import logging
from pathlib import Path
import numpy as np
import torch
import torchmetrics

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import Dict, Optional, Union, List
import pandas as pd

from forecasting_model.utils.forecast_config_model import EvaluationConfig

from forecasting_model.io.plotting import (
    setup_plot_style,
    save_plot,
    create_time_series_plot,
    create_scatter_plot,
    create_residuals_plot,
    create_residuals_distribution_plot,
)


logger = logging.getLogger(__name__)

# --- Fold Evaluation Function ---
def evaluate_fold_predictions(
    y_true_scaled: np.ndarray,
    y_pred_scaled: np.ndarray,
    target_scaler: Union[StandardScaler, MinMaxScaler, None],
    data_scaler: Union[StandardScaler, MinMaxScaler, None],
    eval_config: EvaluationConfig,
    fold_num: int,  # Zero-based fold index
    output_dir: str,  # Base output directory
    plot_subdir: Optional[str] = "plots",
    prediction_time_index: Optional[pd.Index] = None,
    forecast_horizons: Optional[List[int]] = None,
    plot_title_prefix: Optional[str] = None
) -> Dict[str, float]:
    """
    Processes prediction results (multiple horizons) for a fold or ensemble.

    Takes scaled predictions and targets (shape: samples, num_horizons),
    inverse transforms them using the target_scaler, calculates overall metrics
    (MAE, RMSE) across all horizons, and generates evaluation plots for a
    selected subset of the predicted horizons.

    Args:
        y_true_scaled: Numpy array of scaled ground truth targets (n_samples, len(horizons)).
        y_pred_scaled: Numpy array of scaled model predictions (n_samples, len(horizons)).
        target_scaler: The scaler fitted on the target variable. Used for inverse transform.
        data_scaler: The scaler fitted on the input features (kept for potential future use or context, not used in current calculations).
        eval_config: Configuration object for evaluation parameters.
        fold_num: The current fold number (zero-based, or -1 for a classic run).
        output_dir: The base directory to save outputs.
        plot_subdir: Specific subdirectory under output_dir for plots.
        prediction_time_index: Pandas Index representing the time for each prediction point (n_samples,).
                               Required for meaningful time plots.
        forecast_horizons: List of horizons predicted (e.g., [1, 6, 12]). Required for plotting.
        plot_title_prefix: Optional string to prepend to plot titles.

    Returns:
        Dictionary containing evaluation metrics {'MAE': value, 'RMSE': value} on the
        original scale, calculated *across all predicted horizons*.
    """
    fold_id_str = f"Fold {fold_num + 1}" if fold_num >= 0 else "Classic Run"
    eval_context_str = f"{plot_title_prefix} {fold_id_str}" if plot_title_prefix else fold_id_str
    logger.info(f"Processing evaluation results for: {eval_context_str}")

    if y_true_scaled.shape != y_pred_scaled.shape:
        raise ValueError(f"Shape mismatch between targets and predictions for {eval_context_str}: "
                         f"{y_true_scaled.shape} vs {y_pred_scaled.shape}")
    if y_true_scaled.ndim != 2:
        raise ValueError(f"Expected 2D arrays (samples, num_horizons) for {eval_context_str}, got {y_true_scaled.ndim}D")

    n_samples, n_horizons = y_true_scaled.shape
    logger.debug(f"Processing {n_samples} samples across {n_horizons} horizons for {eval_context_str}.")

    # --- Inverse Transform (Outputs NumPy) ---
    # Flatten the multi-horizon arrays for the scaler (which expects (N, 1))
    y_true_flat_scaled = y_true_scaled.reshape(-1, 1)  # Shape: (n_samples * n_horizons, 1)
    y_pred_flat_scaled = y_pred_scaled.reshape(-1, 1)  # Shape: (n_samples * n_horizons, 1)

    y_true_inv_np: np.ndarray
    y_pred_inv_np: np.ndarray

    if target_scaler is not None:
        try:
            logger.debug(f"Inverse transforming predictions and targets for {eval_context_str}.")
            y_true_inv_flat = target_scaler.inverse_transform(y_true_flat_scaled)
            y_pred_inv_flat = target_scaler.inverse_transform(y_pred_flat_scaled)
            # Reshape back to (n_samples, n_horizons) for potential per-horizon analysis later
            y_true_inv_np = y_true_inv_flat.reshape(n_samples, n_horizons)
            y_pred_inv_np = y_pred_inv_flat.reshape(n_samples, n_horizons)
        except Exception as e:
            logger.error(f"Error during inverse scaling for {eval_context_str}: {e}", exc_info=True)
            logger.error("Metrics calculation will be skipped due to inverse transform failure.")
            return {'MAE': np.nan, 'RMSE': np.nan}

    else:
        logger.info(f"No target scaler provided for {eval_context_str}, assuming inputs are on the original scale.")
        y_true_inv_np = y_true_scaled  # Keep original shape (n_samples, n_horizons)
        y_pred_inv_np = y_pred_scaled  # Keep original shape

    # --- Calculate Metrics using torchmetrics.functional (Overall across all horizons) ---
    metrics: Dict[str, float] = {'MAE': np.nan, 'RMSE': np.nan}
    try:
        # Flatten arrays for the overall metrics calculation
        y_true_flat_for_metrics = y_true_inv_np.flatten()
        y_pred_flat_for_metrics = y_pred_inv_np.flatten()

        valid_mask = ~np.isnan(y_true_flat_for_metrics) & ~np.isnan(y_pred_flat_for_metrics)
        if np.sum(valid_mask) < len(y_true_flat_for_metrics):
            nan_count = len(y_true_flat_for_metrics) - np.sum(valid_mask)
            logger.warning(f"{nan_count} NaN values found in predictions/targets (across all horizons) for {eval_context_str}. These will be excluded from metrics.")

        if np.sum(valid_mask) > 0:
            y_true_tensor = torch.from_numpy(y_true_flat_for_metrics[valid_mask]).float().cpu()
            y_pred_tensor = torch.from_numpy(y_pred_flat_for_metrics[valid_mask]).float().cpu()

            mae_tensor = torchmetrics.functional.mean_absolute_error(y_pred_tensor, y_true_tensor)
            mse_tensor = torchmetrics.functional.mean_squared_error(y_pred_tensor, y_true_tensor)
            rmse_tensor = torch.sqrt(mse_tensor)

            metrics['MAE'] = mae_tensor.item()
            metrics['RMSE'] = rmse_tensor.item()

            logger.info(f"{eval_context_str} Test Set Overall Metrics (torchmetrics): MAE={metrics['MAE']:.4f}, RMSE={metrics['RMSE']:.4f} (across all horizons)")
        else:
            logger.warning(f"Skipping metric calculation for {eval_context_str} due to no valid (non-NaN) data points.")

    except Exception as e:
        logger.error(f"Failed to calculate overall metrics using torchmetrics for {eval_context_str}: {e}", exc_info=True)

    # --- Generate Plots (Optional - selected horizons only) ---
    if eval_config.save_plots and np.sum(valid_mask) > 0:
        if forecast_horizons is None or not forecast_horizons:
            logger.warning(f"Skipping plot generation for {eval_context_str}: `forecast_horizons` list not provided.")
        elif prediction_time_index is None or len(prediction_time_index) != n_samples:
            logger.warning(f"Skipping plot generation for {eval_context_str}: `prediction_time_index` is missing or has incorrect length ({len(prediction_time_index) if prediction_time_index is not None else 'None'} != {n_samples}).")
        else:
            logger.info(f"Generating evaluation plots for {eval_context_str} (selected horizons only)...")
            base_plot_dir = Path(output_dir)
            fold_plot_dir = base_plot_dir / plot_subdir if plot_subdir else base_plot_dir
            setup_plot_style()

            for h_i, horizon in enumerate(forecast_horizons):
                # Only generate plots for a subset of horizons (first, last, and a few common ones)
                if horizon in [forecast_horizons[0], forecast_horizons[-1], 24, 48, 12]:
                    y_true_h1 = y_true_inv_np[:, h_i]
                    y_pred_h1 = y_pred_inv_np[:, h_i]
                    residuals_h1 = y_true_h1 - y_pred_h1

                    # Calculate the actual time index for this horizon's targets
                    # Requires the original dataset's frequency if available, otherwise assumes a simple offset
                    target_time_index_h1 = prediction_time_index
                    try:
                        # Assuming prediction_time_index corresponds to the *time* of prediction,
                        # the target for H+h occurs `h` steps later.
                        # This requires a DatetimeIndex with a frequency.
                        if isinstance(prediction_time_index, pd.DatetimeIndex) and prediction_time_index.freq:
                            time_offset = pd.Timedelta(horizon, unit=prediction_time_index.freq.name)
                            target_time_index_h1 = prediction_time_index + time_offset
                            xlabel_h1 = f"Time (Target H+{horizon})"
                        else:
                            logger.warning(f"Prediction time index lacks frequency info. Using original prediction time for H+{horizon} plot x-axis.")
                            xlabel_h1 = f"Prediction Time (Plotting H+{horizon})"
                    except Exception as time_err:
                        logger.warning(f"Could not calculate target time index for H+{horizon}: {time_err}. Using prediction time index for x-axis.")
                        xlabel_h1 = f"Prediction Time (Plotting H+{horizon})"

                    title_suffix = f"- {eval_context_str} (H+{horizon})"

                    try:
                        fig_ts = create_time_series_plot(
                            target_time_index_h1, y_true_h1, y_pred_h1,  # Data and time for this horizon
                            f"Predictions vs Actual {title_suffix}",
                            xlabel=xlabel_h1, ylabel="Value (Original Scale)",
                            max_points=eval_config.plot_sample_size
                        )
                        save_plot(fig_ts, fold_plot_dir / f"predictions_vs_actual_h{horizon}.png")

                        fig_scatter = create_scatter_plot(
                            y_true_h1, y_pred_h1,  # Data for this horizon
                            f"Scatter Plot {title_suffix}",
                            xlabel="Actual Values (Original Scale)", ylabel="Predicted Values (Original Scale)"
                        )
                        save_plot(fig_scatter, fold_plot_dir / f"scatter_predictions_h{horizon}.png")

                        fig_res_time = create_residuals_plot(
                            target_time_index_h1, residuals_h1,  # Residuals and time for this horizon
                            f"Residuals Over Time {title_suffix}",
                            xlabel=xlabel_h1, ylabel="Residual (Original Scale)",
                            max_points=eval_config.plot_sample_size
                        )
                        save_plot(fig_res_time, fold_plot_dir / f"residuals_time_h{horizon}.png")
                        if horizon == forecast_horizons[-1]:
                            # The residual distribution can use residuals from ALL horizons
                            residuals_all = y_true_inv_np.flatten() - y_pred_inv_np.flatten()
                            fig_res_dist = create_residuals_distribution_plot(
                                residuals_all,  # Use all residuals
                                f"Residuals Distribution {eval_context_str} (All Horizons)",  # Adjusted title
                                xlabel="Residual Value (Original Scale)", ylabel="Density"
                            )
                            save_plot(fig_res_dist, fold_plot_dir / "residuals_distribution_all_horizons.png")

                        logger.info(f"Evaluation plots saved to: {fold_plot_dir}")

                    except Exception as e:
                        logger.error(f"Failed to generate or save one or more plots for {eval_context_str}: {e}", exc_info=True)

    elif eval_config.save_plots and np.sum(valid_mask) == 0:
        logger.warning(f"Skipping plot generation for {eval_context_str} due to no valid data points.")

    logger.info(f"Evaluation processing finished for {eval_context_str}.")
    return metrics
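# --- Editor's note: illustrative call, not part of this commit. Array shapes are
# assumed, and an eval_config instance (an EvaluationConfig) is assumed to exist.
import numpy as np

y_true = np.random.rand(100, 3)  # (n_samples, n_horizons), already scaled
y_pred = np.random.rand(100, 3)
metrics = evaluate_fold_predictions(
    y_true_scaled=y_true, y_pred_scaled=y_pred,
    target_scaler=None, data_scaler=None,  # None: values treated as original scale
    eval_config=eval_config, fold_num=0, output_dir="output",
    forecast_horizons=[1, 6, 12],
)
print(metrics)  # {'MAE': ..., 'RMSE': ...}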
@@ -1,9 +1,8 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional, List, Union, Literal
from enum import Enum


# --- Nested Configs ---

class WaveletTransformConfig(BaseModel):
    """Configuration for optional wavelet transform features."""
    apply: bool = False
@@ -49,14 +48,13 @@ class FeatureConfig(BaseModel):
    rolling_window_sizes: List[int] = []
    use_time_features: bool = True
    sinus_curve: bool = False  # Added
-    cosin_curve: bool = False  # Added
+    cosine_curve: bool = False  # Added
    wavelet_transform: Optional[WaveletTransformConfig] = None
    fill_nan: Optional[Union[str, float, int]] = 'ffill'  # Added (e.g., 'ffill', 0)
    clipping: ClippingConfig = ClippingConfig()  # Default instance
    scaling_method: Optional[Literal['standard', 'minmax']] = 'standard'  # Added literal validation

    @field_validator('lags', 'rolling_window_sizes', 'forecast_horizon')
    @classmethod
    def check_positive_list_values(cls, v: List[int]) -> List[int]:
        if any(val <= 0 for val in v):
            raise ValueError('Lists lags, rolling_window_sizes, and forecast_horizon must contain only positive values')
@@ -75,7 +73,8 @@ class ModelConfig(BaseModel):
class TrainingConfig(BaseModel):
    """Configuration for the training process (PyTorch Lightning)."""
    batch_size: int = Field(..., gt=0)
-    epochs: int = Field(..., gt=0)  # Max epochs
+    epochs: int = Field(..., gt=0)
+    check_val_n_epoch: int = Field(..., gt=0)
    learning_rate: float = Field(..., gt=0.0)
    loss_function: Literal['MSE', 'MAE'] = 'MSE'
    # device: str = 'auto'  # Handled by PL Trainer accelerator/devices args
@@ -3,15 +3,15 @@ import json
import logging
import random
from pathlib import Path
-from typing import Optional, List, Dict
+from typing import Optional, List, Dict, Type

import numpy as np
import pandas as pd
import torch

import yaml
from pydantic import BaseModel

from forecasting_model import MainConfig

# Get the root logger
logger = logging.getLogger(__name__)
@@ -35,15 +35,16 @@ def parse_arguments():
    return args


-def load_config(config_path: Path) -> MainConfig:
+def load_config(config_path: Path, config_cls: Type[BaseModel]) -> BaseModel:
    """
    Load and validate configuration from a YAML file using Pydantic.

    Args:
        config_path: Path to the YAML configuration file.
        config_cls: Pydantic configuration class to validate against.

    Returns:
-        Validated MainConfig object.
+        Validated BaseModel instance of the given config_cls.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
@@ -60,7 +61,7 @@ def load_config(config_path: Path) -> MainConfig:
        config_dict = yaml.safe_load(f)

        # Validate the configuration using the Pydantic model
-        config = MainConfig(**config_dict)
+        config = config_cls(**config_dict)
        logger.info("Configuration loaded and validated successfully.")
        return config
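# --- Editor's note: illustrative usage, not part of this commit. The YAML file
# path is assumed.
from pathlib import Path
from forecasting_model import MainConfig

config = load_config(Path("config.yaml"), MainConfig)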
    except yaml.YAMLError as e:
@@ -171,3 +172,75 @@ class NumpyEncoder(json.JSONEncoder):
        elif pd.isna(obj):  # Handle pandas NaT or numpy NaN gracefully
            return None
        return super(NumpyEncoder, self).default(obj)
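# --- Editor's note: illustrative usage, not part of this commit. It exercises the
# NaT/NaN fallback branch visible above; the encoder's earlier branches are not
# shown in this hunk.
import json
import pandas as pd

payload = {"mae": 1.23, "timestamp": pd.NaT}
print(json.dumps(payload, cls=NumpyEncoder))  # {"mae": 1.23, "timestamp": null}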


def calculate_h1_target_index(
    full_df: pd.DataFrame,
    test_idx: Optional[np.ndarray],
    sequence_length: int,
    forecast_horizon: Optional[list[int]],
    n_predictions: int,
    fold_id: int  # For logging context
) -> Optional[pd.DatetimeIndex]:
    """
    Calculates the DatetimeIndex for the first forecast horizon (h1) targets.

    Args:
        full_df: The complete dataset DataFrame with a DatetimeIndex.
        test_idx: The indices marking the start of sequences in the test set.
        sequence_length: The length of the input sequences.
        forecast_horizon: List of forecast steps (e.g., [1, 2, ...]).
        n_predictions: The number of predictions generated for the test set.
        fold_id: The current fold number (1-based) for logging.

    Returns:
        A pandas DatetimeIndex corresponding to the h1 targets, or None if the calculation fails.
    """
    if test_idx is None or not forecast_horizon:
        logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx or forecast_horizon).")
        return None

    if n_predictions == 0:
        logger.warning(f"Fold {fold_id}: Skipping target time index calculation (n_predictions is 0).")
        return None

    try:
        first_horizon = forecast_horizon[0]
        # Calculate the theoretical index relative to the start of full_df
        target_indices_h1 = test_idx + sequence_length + first_horizon - 1

        # Filter out indices that would fall outside the bounds of the dataframe
        valid_mask = target_indices_h1 < len(full_df)
        valid_target_indices_h1 = target_indices_h1[valid_mask]

        # Check if we have enough valid indices for the predictions made
        if len(valid_target_indices_h1) < n_predictions:
            logger.warning(
                f"Fold {fold_id}: Cannot calculate target time index for h1; "
                f"insufficient valid indices ({len(valid_target_indices_h1)}) "
                f"for the {n_predictions} predictions made."
            )
            return None

        # Select the indices corresponding to the actual predictions
        # We assume the predictions correspond to the first possible valid targets
        indices_for_predictions = valid_target_indices_h1[:n_predictions]

        # Retrieve the actual DatetimeIndex values
        prediction_target_time_index_h1 = full_df.index[indices_for_predictions]

        # Final sanity check on length
        if len(prediction_target_time_index_h1) != n_predictions:
            logger.warning(
                f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
                f"does not match the prediction count ({n_predictions}). The plotting x-axis might be misaligned."
            )
            return None

        return prediction_target_time_index_h1

    except IndexError as e:
        logger.error(f"Fold {fold_id}: Error accessing the DataFrame index during target time calculation (potentially out of bounds): {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Fold {fold_id}: Error calculating the target time index for plotting: {e}", exc_info=True)
        return None
|
||||
|
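# Worked sketch of the index arithmetic above (hypothetical toy data): with
# sequence_length=24 and forecast_horizon=[1], a test sequence starting at row i
# consumes rows i..i+23 as input, so its h1 target sits at row i + 24 + 1 - 1.
#
#   idx = pd.date_range("2024-01-01", periods=100, freq="h")
#   full_df = pd.DataFrame({"price": np.arange(100)}, index=idx)
#   h1 = calculate_h1_target_index(full_df, test_idx=np.array([50, 51, 52]),
#                                  sequence_length=24, forecast_horizon=[1],
#                                  n_predictions=3, fold_id=1)
#   # -> rows 74..76, i.e. 2024-01-04 02:00 .. 04:00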
@ -13,21 +13,24 @@ from pytorch_lightning.loggers import CSVLogger
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Import necessary components from your project structure,
# assuming forecasting_model is an installable package or on PYTHONPATH.
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
    load_raw_data,
    TimeSeriesCrossValidationSplitter,
from forecasting_model.utils.data_processing import (
    prepare_fold_data_and_loaders
)
from forecasting_model.utils.dataset_splitter import TimeSeriesCrossValidationSplitter
from forecasting_model.io.data import load_raw_data
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.utils.evaluation import evaluate_fold_predictions
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation

# Import the new classic training function
from forecasting_model.train.classic import run_classic_training
from forecasting_model.train.classic import run_model_training
from typing import Dict, List, Optional, Tuple, Union
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
from forecasting_model.utils.helper import (
    parse_arguments, load_config,
    set_seeds, aggregate_cv_metrics,
    save_results, calculate_h1_target_index
)
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot

# Silence overly verbose libraries if needed
@ -46,6 +49,7 @@ logger = logging.getLogger()


# --- Single Fold Processing Function ---
# noinspection PyInconsistentReturns
def run_single_fold(
    fold_num: int,
    train_idx: np.ndarray,
@ -53,8 +57,9 @@ def run_single_fold(
    test_idx: np.ndarray,
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path  # Receives Path object from run_training_pipeline
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
    output_base_dir: Path,
    enable_progress_bar: bool = True
) -> Optional[Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]]:
    """
    Runs the pipeline for a single cross-validation fold.

@ -66,6 +71,7 @@ def run_single_fold(
        config: The main configuration object.
        full_df: The complete raw DataFrame.
        output_base_dir: The base directory Path for saving results.
        enable_progress_bar: Whether to enable the training progress bar.

    Returns:
        A tuple containing:
@ -73,6 +79,7 @@ def run_single_fold(
        - best_val_score: The best validation score achieved during training (or None).
        - saved_model_path: Path to the best saved model checkpoint (or None).
        - saved_target_scaler_path: Path to the saved target scaler (or None).
        - saved_data_scaler_path: Path to the saved data feature scaler (or None).
        - saved_input_size_path: Path to the saved input size file (or None).
        - saved_config_path: Path to the saved config file for this fold (or None).
    """
@ -92,19 +99,23 @@ def run_single_fold(
    all_preds_scaled: Optional[np.ndarray] = None
    all_targets_scaled: Optional[np.ndarray] = None
    target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None  # Need to keep scaler reference
    data_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None  # Added to keep data scaler reference
    prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
    pl_logger = None

    # Variables to store paths of saved artifacts
    saved_model_path: Optional[Path] = None
    saved_target_scaler_path: Optional[Path] = None
    saved_data_scaler_path: Optional[Path] = None  # Added
    saved_input_size_path: Optional[Path] = None
    saved_config_path: Optional[Path] = None

    try:
        # --- Per-Fold Data Preparation ---
        logger.info("Preparing data loaders for the fold...")
        # Keep scaler and input_size references returned by prepare_fold_data_and_loaders
        train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders(  # Renamed target_scaler
        # Assume prepare_fold_data_and_loaders returns the data_scaler as the 5th element;
        # modify this call based on its actual return signature.
        train_loader, val_loader, test_loader, target_scaler_fold, data_scaler_fold, input_size = prepare_fold_data_and_loaders(
            full_df=full_df,
            train_idx=train_idx,
            val_idx=val_idx,
@ -114,13 +125,17 @@ def run_single_fold(
            train_config=config.training,
            eval_config=config.evaluation
        )
        target_scaler = target_scaler_fold  # Store the scaler in the outer scope
        target_scaler = target_scaler_fold  # Store the target scaler in the outer scope
        data_scaler = data_scaler_fold  # Store the data scaler in the outer scope
        logger.info(f"Data loaders prepared. Input size determined: {input_size}")

        # Save necessary items for potential later use (e.g., ensemble)
        # Save necessary items for potential later use (e.g., ensemble, inference)
        # Capture the paths when saving
        saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
        torch.save(target_scaler, saved_target_scaler_path)
        saved_data_scaler_path = fold_output_dir / "data_scaler.pt"
        torch.save(data_scaler, saved_data_scaler_path)

        torch.save(test_loader, fold_output_dir / "test_loader.pt")  # Test loader might be large, consider if needed

        # Save input size and capture path
@ -140,13 +155,14 @@ def run_single_fold(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler_fold  # Pass scaler during init
            target_scaler=target_scaler_fold,
            data_scaler=data_scaler
        )
        logger.info("LSTMForecastLightningModule initialized.")

        # --- PyTorch Lightning Callbacks ---
        # Ensure monitor_metric matches the exact name logged in model.py
        monitor_metric = "val_MeanAbsoluteError_Original_Scale"  # Corrected metric name
        monitor_metric = "val_MeanAbsoluteError"  # Corrected metric name
        monitor_mode = "min"

        early_stop_callback = None
@ -174,6 +190,7 @@ def run_single_fold(

        callbacks = [checkpoint_callback, lr_monitor]
        if early_stop_callback:
            # noinspection PyTypeChecker
            callbacks.append(early_stop_callback)

        # --- PyTorch Lightning Logger ---
@ -190,12 +207,13 @@ def run_single_fold(

        trainer = pl.Trainer(
            accelerator=accelerator,
            check_val_every_n_epoch=config.training.check_val_n_epoch,
            devices=devices,
            enable_progress_bar=enable_progress_bar,
            max_epochs=config.training.epochs,
            callbacks=callbacks,
            logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader) // 10),
            enable_progress_bar=True,
            gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
            precision=precision,
        )
@ -262,59 +280,33 @@ def run_single_fold(
logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")
|
||||
|
||||
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
|
||||
prediction_target_time_index_h1 = calculate_h1_target_index(
|
||||
full_df=full_df,
|
||||
test_idx=test_idx,
|
||||
sequence_length=config.features.sequence_length,
|
||||
forecast_horizon=config.features.forecast_horizon,
|
||||
n_predictions=n_predictions,
|
||||
fold_id=fold_id
|
||||
)
|
||||
|
||||
# --- Handle Saving/Cleanup of the Index File ---
|
||||
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
|
||||
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
|
||||
if prediction_target_time_index_h1 is not None and config.evaluation.save_plots:
|
||||
# Save the calculated index if valid and plots are enabled
|
||||
try:
|
||||
test_block_index = full_df.index[test_idx]
|
||||
seq_len = config.features.sequence_length
|
||||
first_horizon = config.features.forecast_horizon[0]
|
||||
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
|
||||
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
|
||||
except Exception as save_e:
|
||||
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
|
||||
elif prediction_target_time_index_h1_path.exists():
|
||||
# Remove outdated file if index is invalid/not calculated OR plots disabled
|
||||
logger.debug(f"Removing potentially outdated time index file: {prediction_target_time_index_h1_path}")
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
|
||||
target_indices_h1 = test_idx + seq_len + first_horizon - 1
|
||||
|
||||
valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
|
||||
valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]
|
||||
|
||||
if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
|
||||
prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
|
||||
if len(prediction_target_time_index_h1) != n_predictions:
|
||||
logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
|
||||
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
|
||||
# Save the calculated index if it's valid and evaluation plots are enabled
|
||||
if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
|
||||
try:
|
||||
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
|
||||
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
|
||||
except Exception as save_e:
|
||||
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
|
||||
|
||||
elif prediction_target_time_index_h1_path.exists():
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
logger.debug("Removed outdated prediction target time index h1 file.")
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
|
||||
prediction_target_time_index_h1 = None
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
|
||||
if prediction_target_time_index_h1_path.exists():
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
# --- End Index Calculation and Saving ---
|
||||
|
||||
|
||||
@ -324,6 +316,7 @@ def run_single_fold(
|
||||
            y_true_scaled=all_targets_scaled,  # Pass the (N, H) array
            y_pred_scaled=all_preds_scaled,  # Pass the (N, H) array
            target_scaler=target_scaler,
            data_scaler=data_scaler,
            eval_config=config.evaluation,
            fold_num=fold_num,  # Pass zero-based index
            output_dir=str(fold_output_dir),
@ -331,7 +324,7 @@ def run_single_fold(
            # Pass the calculated index for the targets being plotted (h1 reference)
            prediction_time_index=prediction_target_time_index_h1,  # Use the calculated index here (for h1)
            forecast_horizons=config.features.forecast_horizon,  # Pass the list of horizons
            plot_title_prefix=f"CV Fold {fold_id}"
            plot_title_prefix=f"CV Fold {fold_id}",
        )
            save_results(fold_metrics, fold_output_dir / "test_metrics.json")
        else:
@ -376,31 +369,14 @@ def run_single_fold(
        except Exception as e:
            logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)

        # --- Plot Loss Curve for Fold ---
        try:
            actual_log_dir = Path(pl_logger.log_dir) / pl_logger.name  # Should be .../fold_XX/training_logs
            metrics_file_path = actual_log_dir / "metrics.csv"

            if metrics_file_path.is_file():
                plot_loss_curve_from_csv(
                    metrics_csv_path=metrics_file_path,
                    output_path=fold_output_dir / "plots" / "loss_curve.png",  # Save in plots subdir
                    title=f"Fold {fold_id} Training Progression",
                    train_loss_col='train_loss',
                    val_loss_col='val_loss'  # This function handles fallback
                )
                logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
            else:
                logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
        except Exception as e:
            logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
        # --- End Loss Curve Plotting ---

    except Exception as e:
        logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
        # Ensure paths are None if an error occurs before they are set
        if saved_model_path is None: saved_model_path = None
        if saved_target_scaler_path is None: saved_target_scaler_path = None
        if saved_data_scaler_path is None: saved_data_scaler_path = None  # Added check
        if saved_input_size_path is None: saved_input_size_path = None
        if saved_config_path is None: saved_config_path = None


    finally:
@ -413,11 +389,39 @@ def run_single_fold(
        # Delete loaders explicitly if they might hold references
        del train_loader, val_loader, test_loader

        # --- Plot Loss Curve for Fold ---
        if pl_logger and hasattr(pl_logger, 'log_dir') and pl_logger.log_dir:  # Check if logger exists and has log_dir
            try:
                # Use the logger's log_dir directly; it already includes the 'name' segment
                actual_log_dir = Path(pl_logger.log_dir)  # FIX: Remove appending pl_logger.name
                metrics_file_path = actual_log_dir / "metrics.csv"

                if metrics_file_path.is_file():
                    plot_loss_curve_from_csv(
                        metrics_csv_path=metrics_file_path,
                        # Save plot inside the specific fold's plot directory
                        output_path=fold_output_dir / "plots" / "loss_curve.png",
                        title=f"Fold {fold_id} Training Progression",
                        train_loss_col='train_loss',  # Ensure these column names match your CSVLogger output
                        val_loss_col='val_loss'  # Ensure these column names match your CSVLogger output
                    )
                    logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
                else:
                    logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
            except AttributeError:
                logger.warning(f"Fold {fold_id}: Could not plot loss curve, CSVLogger object or log_dir attribute missing.")
            except Exception as e:
                logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
        else:
            logger.warning(f"Fold {fold_id}: Skipping loss curve plot generation as CSVLogger was not properly initialized or log_dir is missing.")
        # --- End Loss Curve Plotting ---

        fold_end_time = time.perf_counter()
        logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")

    # Return the calculated fold metrics, best validation score, and saved artifact paths
    return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path
    return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path


# --- Main Training & Evaluation Function ---
@ -450,8 +454,8 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
        sys.exit(1)

    for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
        # Unpack the two new return values from run_single_fold
        fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
        # Unpack the return values from run_single_fold, including the new data_scaler path
        fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, _input_size_path, _config_path = run_single_fold(
            fold_num=fold_num,
            train_idx=train_idx,
            val_idx=val_idx,
@ -503,7 +507,7 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
    classic_output_dir = output_base_dir / "classic_run"  # Define dir for logging path
    try:
        # Call the original classic training function directly
        classic_metrics = run_classic_training(
        classic_metrics = run_model_training(
            config=config,
            full_df=df,
            output_base_dir=output_base_dir  # It creates the classic_run subdir internally
@ -597,7 +601,7 @@ def run():

    # --- Configuration Loading ---
    try:
        config = load_config(config_path)
        config = load_config(config_path, MainConfig)
    except Exception:
        # Error already logged in load_config
        sys.exit(1)
@ -629,4 +633,7 @@ def run():
        sys.exit(1)

if __name__ == "__main__":
    run()
    raise DeprecationWarning(
        "This was the initial script for training and is no longer maintained!\n Exiting...."
    )
    exit(-9999)
@ -11,15 +11,31 @@ max_rate: 1.0

# The length of the time window (in hours) for which the optimization is run.
# This should match the forecast horizon of the models being evaluated.
optimization_horizon_hours: 24
optimization_horizon_hours: 12

# Output directory for the optimization results
output_dir: 'output/optimization_results'

# List of models to evaluate. Each entry includes the path to the model's
# forecast output file and the path to the forecasting config used for that model.
models:
- name: "Model_A"
  forecast_path: "path/to/model_a_forecast_output.csv" # Path to the file containing forecast time points and prices
  forecast_config_path: "configs/model_a_forecasting_config.yaml" # Path to the forecasting config used for this model
- name: "Model_B"
  forecast_path: "path/to/model_b_forecast_output.csv"
  forecast_config_path: "configs/model_b_forecasting_config.yaml"
# Add more models here
- name: "LSTM-Single-Model"
  type: "model"
  # Path to the saved PyTorch model file (.ckpt for type='model') or the ensemble definition JSON file (.json for type='ensemble').
  model_path: 'output/classic/best_trial_num90/classic_run/checkpoints/best_classic_model.ckpt'
  # Path to the forecasting config (YAML) used for this model's training (or for the best trial in an ensemble)
  model_config_path: 'output/classic/best_trial_num90/classic_run/config.yaml'
  # Path to the target scaler file for the single model (or will be loaded per fold for an ensemble).
  target_scaler_path: 'output/classic/best_trial_num90/classic_run/classic_target_scaler.pt'
  # Path to the data scaler file for the single model (or will be loaded per fold for an ensemble).
  data_scaler_path: 'output/classic/best_trial_num90/classic_run/classic_data_scaler.pt'
  # Path to the input size file for the single model (or will be loaded per fold for an ensemble).
  input_size_path: 'output/classic/best_trial_num90/classic_run/classic_input_size.pt'
- name: "LSTM-Ensemble"
  type: "ensemble"
  model_path: "output/ensemble/lstm_price_forecast_best_ensemble.json"
  model_config_path: "output/ensemble/lstm_price_forecast_best_config.yaml"
  # Use YAML null (not the string "None") so these resolve to Python None.
  target_scaler_path: null
  data_scaler_path: null
  input_size_path: null

861	optim_run.py
File diff suppressed because it is too large
626	optim_run_daa.py	Normal file
@ -0,0 +1,626 @@
import pandas as pd
import numpy as np

import logging
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import timedelta

from pydantic import BaseModel
from tqdm import tqdm

# Import Forecasting Providers
from forecasting_model.utils.helper import load_config
from optimizer.forecasting.base import ForecastProvider
from optimizer.forecasting.single_model import SingleModelProvider
from optimizer.forecasting.ensemble import EnsembleProvider

from optimizer.optimization.battery import solve_battery_optimization_hourly
from optimizer.utils.optimizer_config_model import OptimizationRunConfig
from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig  # , DataConfig

# Import the loading functions
from optimizer.utils.model_io import load_single_model_artifact, load_ensemble_artifact

# Feature Engineering Import
from forecasting_model import engineer_features, load_raw_data

from typing import Dict, Any, Optional, Type, List, Tuple

# Silence overly verbose libraries if needed
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)-7s - %(message)s',
                    datefmt='%H:%M:%S')
# Get the root logger
logger = logging.getLogger()

def load_optimization_config(config_path: str) -> OptimizationRunConfig | None:
    """Loads the main optimization configuration from a YAML file."""
    return load_config(Path(config_path), OptimizationRunConfig)


def load_main_forecasting_config(config_path: str) -> MainConfig | None:
    """Loads the main forecasting configuration from a YAML file."""
    return load_config(Path(config_path), MainConfig)


def check_and_adjust_decisions_given_state(decision_p: float, current_state_b: float, max_capacity_b: float) -> tuple[float, float]:
    """Clips a decision to the battery capacity bounds (0 to max_capacity)."""
    potential_state = current_state_b + decision_p
    new_state = max(0.0, min(potential_state, max_capacity_b))
    valid_decision = new_state - current_state_b
    return valid_decision, new_state

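# Worked sketch (illustrative numbers): a planned charge of +5.0 at state 8.0
# with max_capacity 10.0 is clipped to +2.0, leaving the battery full; a
# discharge below empty is clipped symmetrically to 0.0:
#
#   executed, new_state = check_and_adjust_decisions_given_state(5.0, 8.0, 10.0)
#   # executed == 2.0, new_state == 10.0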
def simulate_daily_schedule(
    planned_p_schedule: np.ndarray,
    actual_prices_day: np.ndarray,
    initial_b_state: float,
    max_capacity: float,
    max_rate: float  # Although max_rate isn't used in check_and_adjust, it's conceptually part of the battery limits
) -> Tuple[float, float, List[float], List[float]]:
    """
    Simulates the execution of a planned 24h schedule against actual prices.

    Args:
        planned_p_schedule: Array of 24 planned power decisions (charge positive).
        actual_prices_day: Array of 24 actual prices for the day.
        initial_b_state: Battery state at the beginning of the day (00:00).
        max_capacity: Maximum battery capacity.
        max_rate: Maximum charge/discharge rate (implicitly handled by optimizer output).

    Returns:
        Tuple containing:
        - total_daily_profit: The profit realized over the 24 hours.
        - final_b_state_day_end: Battery state at the end of the day (after the 23:00 action).
        - executed_p_hourly: List of the actual power decisions made each hour.
        - actual_b_hourly: List of the battery state *before* each hourly action.
    """
    if len(planned_p_schedule) != 24 or len(actual_prices_day) != 24:
        raise ValueError(f"Inputs must have length 24. Got schedule: {len(planned_p_schedule)}, prices: {len(actual_prices_day)}")

    current_b_state = initial_b_state
    total_daily_profit = 0.0
    executed_p_hourly = []
    actual_b_hourly = []  # State at the beginning of each hour

    for h in range(24):
        planned_p_decision = planned_p_schedule[h]
        actual_price_hour = actual_prices_day[h]

        actual_b_hourly.append(current_b_state)  # Store state *before* action

        # Check feasibility against current state and capacity
        executed_p_decision, next_b_state = check_and_adjust_decisions_given_state(
            planned_p_decision, current_b_state, max_capacity
        )

        # Calculate profit/cost for the action (buy (+P) is a cost, sell (-P) is revenue).
        # TODO: Revisit this. DAA prices are *binding*, so if we cannot actually charge,
        #  do we still have to pay? For now we settle the *planned* volume, which matters
        #  whenever planned_p_decision > executed_p_decision.
        hourly_profit = -1 * planned_p_decision * actual_price_hour
        total_daily_profit += hourly_profit

        # Store executed action and update state for the next hour
        executed_p_hourly.append(executed_p_decision)
        current_b_state = next_b_state

    # current_b_state is now the end-of-day battery state (after the last hour's action)
    return total_daily_profit, current_b_state, executed_p_hourly, actual_b_hourly

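# A minimal usage sketch (hypothetical flat prices; note the settlement uses
# the *planned* volume, per the TODO above): charge 1 MW for the first 12
# hours, discharge it again over the last 12.
#
#   plan = np.array([1.0] * 12 + [-1.0] * 12)
#   prices = np.full(24, 50.0)
#   profit, final_b, executed, states = simulate_daily_schedule(
#       plan, prices, initial_b_state=0.0, max_capacity=12.0, max_rate=1.0)
#   # profit == 0.0 (12 MWh bought and sold at the same price), final_b == 0.0,
#   # and executed == plan since no clipping was needed.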
# --- Main Execution Logic ---
if __name__ == "__main__":
    logger.info("Starting DAA battery optimization evaluation using provider-specific engineered features.")

    # --- Load Main Optimization Config ---
    optimization_config_path = "optim_config.yaml"
    optimization_config = load_optimization_config(optimization_config_path)

    if optimization_config is None:
        logger.critical("Failed to load main optimization config. Exiting.")
        exit(1)

    if not optimization_config.models:
        logger.critical("No models or ensembles specified in optimization config. Exiting.")
        exit(1)

    # Fixed DAA horizon
    OPTIMIZATION_HORIZON_HOURS = 24

    # --- Load Original Historical Data ---
    try:
        first_model_config_path_str = optimization_config.models[0].model_config_path
        main_forecasting_config_for_data = load_main_forecasting_config(first_model_config_path_str)
        if main_forecasting_config_for_data is None:
            raise ValueError(f"Failed to load forecasting config ({first_model_config_path_str}) to get data path.")

        historical_data_config = main_forecasting_config_for_data.data
        target_col = historical_data_config.target_col  # Assume a consistent target for now

        logger.info(f"Loading original historical data from: {historical_data_config.data_path}")
        full_historical_df = load_raw_data(historical_data_config)  # load_raw_data handles initial cleaning

        if full_historical_df.empty or target_col not in full_historical_df.columns:
            raise ValueError(f"Loaded original historical data is empty or missing target column '{target_col}'.")

        if not isinstance(full_historical_df.index, pd.DatetimeIndex):
            raise TypeError("Loaded historical data must have a DatetimeIndex.")
        if full_historical_df.index.freq is None and historical_data_config.expected_frequency:
            logger.warning(f"Data index frequency not set, attempting to set to '{historical_data_config.expected_frequency}'.")
            try:
                full_historical_df = full_historical_df.asfreq(historical_data_config.expected_frequency)
                if full_historical_df[target_col].isnull().any():
                    logger.warning(f"NaNs found after setting frequency. Applying ffill().bfill() to '{target_col}'.")
                    full_historical_df[target_col] = full_historical_df[target_col].ffill().bfill()
                    if full_historical_df[target_col].isnull().any():
                        raise ValueError(f"NaNs remain after filling '{target_col}' post-asfreq.")
            except Exception as e:
                logger.error(f"Failed to set frequency to {historical_data_config.expected_frequency}: {e}", exc_info=True)
                logger.warning("Proceeding without guaranteed index frequency.")

        logger.info(f"Original historical data loaded. Shape: {full_historical_df.shape}, Range: {full_historical_df.index.min()} to {full_historical_df.index.max()}, Freq: {full_historical_df.index.freq}")

    except Exception as e:
        logger.critical(f"Failed during initial data loading: {e}", exc_info=True)
        exit(1)


    # --- Load Models, Engineer Features per Provider, Instantiate Providers ---
    providers_data: Dict[str, Dict[str, Any]] = {}  # Stores {'provider_name': {'provider': instance, 'df': engineered_df, 'first_usable_input_timestamp': timestamp}}

    for model_eval_config in optimization_config.models:
        provider_name = model_eval_config.name
        artifact_type = model_eval_config.type
        artifact_path = Path(model_eval_config.model_path)
        model_config_path = Path(model_eval_config.model_config_path)
        logger.info(f"--- Processing Provider: {provider_name} ({artifact_type}) ---")

        provider_instance: Optional[ForecastProvider] = None
        provider_feature_config: Optional[FeatureConfig] = None
        provider_target_scaler: Optional[Any] = None
        provider_data_scaler: Optional[Any] = None

        try:
            # --- Load Model/Ensemble Artifact (similar to the original script) ---
            if artifact_type == 'model':
                logger.info(f"Loading single model artifact: {provider_name}")
                # (Load artifact details as before...)
                target_scaler_path = Path(model_eval_config.target_scaler_path) if model_eval_config.target_scaler_path else None
                data_scaler_path = Path(model_eval_config.data_scaler_path) if model_eval_config.data_scaler_path else None
                input_size_path = Path(model_eval_config.input_size_path) if model_eval_config.input_size_path else None

                loaded_artifact_info = load_single_model_artifact(
                    model_path=artifact_path,
                    config_path=model_config_path,
                    input_size_path=input_size_path,
                    target_scaler_path=target_scaler_path,
                    data_scaler_path=data_scaler_path
                )
                if not loaded_artifact_info: raise ValueError("load_single_model_artifact returned None.")
                current_main_config = loaded_artifact_info['main_forecasting_config']
                if current_main_config.data.target_col != target_col: raise ValueError("Target column mismatch.")
                provider_feature_config = loaded_artifact_info['feature_config']
                provider_target_scaler = loaded_artifact_info['target_scaler']
                provider_data_scaler = loaded_artifact_info['data_scaler']
                if provider_data_scaler is None: logger.warning(f"Data scaler not found for '{provider_name}'.")

                provider_instance = SingleModelProvider(
                    model_instance=loaded_artifact_info['model_instance'],
                    feature_config=provider_feature_config,
                    target_scaler=provider_target_scaler,
                    data_scaler=provider_data_scaler
                )
                if 1 not in provider_instance.get_forecast_horizons():  # Still a useful check for interpolation
                    raise ValueError(f"Model must forecast horizon 1. Horizons: {provider_instance.get_forecast_horizons()}")

            elif artifact_type == 'ensemble':
                logger.info(f"Loading ensemble artifact: {provider_name}")
                # (Load artifact details as before...)
                hpo_base_output_dir_for_ensemble = artifact_path.parent
                loaded_artifact_info = load_ensemble_artifact(
                    ensemble_definition_path=artifact_path,
                    hpo_base_output_dir=hpo_base_output_dir_for_ensemble
                )
                if not loaded_artifact_info or not loaded_artifact_info.get('fold_artifacts'): raise ValueError("load_ensemble_artifact failed.")
                if not loaded_artifact_info['fold_artifacts'][0].get('feature_config'): raise ValueError("Missing feature_config in ensemble fold.")
                provider_feature_config = loaded_artifact_info['fold_artifacts'][0]['feature_config']

                provider_instance = EnsembleProvider(
                    fold_artifacts=loaded_artifact_info['fold_artifacts'],
                    ensemble_method=loaded_artifact_info['ensemble_method'],
                )
                if 1 not in provider_instance.get_forecast_horizons():
                    raise ValueError("Ensemble has no folds that forecast horizon 1.")

            else:
                raise ValueError(f"Unknown artifact type '{artifact_type}'.")

            # --- Feature Engineering for this Provider ---
            logger.info(f"Engineering features specifically for '{provider_name}'.")
            if provider_feature_config is None: raise RuntimeError("Could not determine feature config.")

            engineered_df_provider = engineer_features(
                full_historical_df.copy(),
                target_col=target_col,
                feature_config=provider_feature_config
            )
            first_valid_index_dt = engineered_df_provider.first_valid_index()
            if pd.isna(first_valid_index_dt): raise ValueError("Engineered DF contains only NaNs.")
            logger.info(f"Feature engineering for '{provider_name}' complete. First valid index: {first_valid_index_dt}")

            # --- Determine First Timestamp Usable for *Starting* a Forecast ---
            # This is the earliest timestamp where the model has enough historical features.
            # The model will use data *up to* this timestamp to predict the *next* steps.
            provider_seq_len = provider_instance.get_required_lookback()
            try:
                first_valid_loc = engineered_df_provider.index.get_loc(first_valid_index_dt)
                first_possible_input_end_loc = first_valid_loc + provider_seq_len - 1
                if first_possible_input_end_loc >= len(engineered_df_provider.index):
                    raise ValueError(f"Not enough data ({len(engineered_df_provider.index)} starting {first_valid_index_dt}) for lookback {provider_seq_len}.")
                # This is the last timestamp included in the *input* for the *first possible* forecast
                first_usable_input_timestamp = engineered_df_provider.index[first_possible_input_end_loc]
                logger.info(f"First usable timestamp as input end (t=0) for '{provider_name}': {first_usable_input_timestamp} (needs seq_len={provider_seq_len})")
            except (KeyError, IndexError) as e:
                raise ValueError(f"Error determining first usable input time for {provider_name}: {e}")

            providers_data[provider_name] = {
                'provider': provider_instance,
                'df': engineered_df_provider,
                'first_usable_input_timestamp': first_usable_input_timestamp
            }
            logger.info(f"Successfully processed provider '{provider_name}'.")

        except Exception as e:
            logger.error(f"Failed to process provider '{provider_name}': {e}. Skipping.", exc_info=True)
            if provider_name in providers_data: del providers_data[provider_name]
            continue

    # --- End Loading/Preparing Providers ---
    if not providers_data:
        logger.critical("No forecast providers successfully created. Exiting.")
        exit(1)

    successfully_loaded_provider_names = list(providers_data.keys())
    logger.info(f"Successfully prepared {len(successfully_loaded_provider_names)} providers: {successfully_loaded_provider_names}")

    # --- Determine Simulation Date Range ---
    # Find the latest "first usable input timestamp" across all providers
    global_first_usable_input_time = max(p_data['first_usable_input_timestamp'] for p_data in providers_data.values())

    # The first *decision day* (Day D) must be such that the forecast input time (e.g., D 23:00)
    # is at or after global_first_usable_input_time.
    # The first *target day* (Day D+1) must have 24h of actual prices available.
    first_decision_day = (global_first_usable_input_time + timedelta(hours=1)).normalize()  # Start of Day D+1
    first_target_day_start = first_decision_day

    # The last *target day* (Day D+1) must end within the historical data
    last_target_day_end = full_historical_df.index.max().normalize() + timedelta(hours=23)  # End of the last full day in data
    last_target_day_start = last_target_day_end.normalize()
    last_decision_day = last_target_day_start - timedelta(days=1)

    # Create the range of target day start times
    simulation_target_days = pd.date_range(
        start=first_target_day_start,
        end=last_target_day_start,
        freq='D'  # Daily frequency
    )

    if simulation_target_days.empty:
        logger.critical(f"Not enough data for DAA simulation. First possible target day: {first_target_day_start}, Last target day: {last_target_day_start}. Check data length and lookbacks.")
        exit(1)

    logger.info(f"Evaluating DAA strategy for target days from {simulation_target_days.min().strftime('%Y-%m-%d')} to {simulation_target_days.max().strftime('%Y-%m-%d')} ({len(simulation_target_days)} days).")

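    # Worked sketch of the alignment (hypothetical timestamp): if the latest
    # usable input end across providers is 2017-03-02 23:00, then
    #
    #   global_first_usable_input_time = pd.Timestamp("2017-03-02 23:00")
    #   (global_first_usable_input_time + timedelta(hours=1)).normalize()
    #   # -> 2017-03-03 00:00, the first target day; its schedule is planned
    #   # from inputs ending at 23:00 the evening before.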
    # --- DAA Evaluation Loop ---
    daily_results_list = []
    # Initialize battery states for the baseline and each provider.
    # We store the state *at the start* of the target day.
    current_b_states = {
        'baseline': optimization_config.initial_b,
        **{name: optimization_config.initial_b for name in successfully_loaded_provider_names}
    }

    for target_day_start in tqdm(simulation_target_days, desc="Simulating DAA Days"):
        decision_day = target_day_start - timedelta(days=1)
        target_day_end = target_day_start + timedelta(hours=OPTIMIZATION_HORIZON_HOURS - 1)
        logger.debug(f"Processing Target Day: {target_day_start.strftime('%Y-%m-%d')}")

        # Timestamp for the forecast generation input (end of the decision day).
        # This is t=0 for the forecast predicting target_day_start onwards.
        forecast_input_end_time = target_day_start - timedelta(hours=1)  # e.g., Day D 23:00

        # Check if forecast_input_end_time is valid for all providers
        if forecast_input_end_time < global_first_usable_input_time:
            logger.warning(f"Skipping target day {target_day_start.strftime('%Y-%m-%d')}: Forecast input time {forecast_input_end_time} is before required {global_first_usable_input_time}.")
            continue

        # Get actual prices for the target day
        try:
            actual_prices_target_day = full_historical_df.loc[target_day_start:target_day_end, target_col].values
            if len(actual_prices_target_day) != OPTIMIZATION_HORIZON_HOURS:
                logger.warning(f"Skipping target day {target_day_start.strftime('%Y-%m-%d')}: Actual prices slice length {len(actual_prices_target_day)} != {OPTIMIZATION_HORIZON_HOURS}. Check data continuity.")
                continue
        except Exception as e:
            logger.warning(f"Skipping target day {target_day_start.strftime('%Y-%m-%d')}: Could not retrieve actual prices. Error: {e}")
            continue

        # Get the last actual price at the forecast input time (needed for potential interpolation)
        try:
            last_actual_price_anchor = full_historical_df.loc[forecast_input_end_time, target_col]
            if pd.isna(last_actual_price_anchor):
                logger.warning(f"Last actual price at {forecast_input_end_time} is NaN. Cannot use as anchor. Skipping day.")
                continue
        except KeyError:
            logger.error(f"Cannot get last actual price at timestamp {forecast_input_end_time}. Skipping day.")
            continue
        except Exception as e:
            logger.error(f"Error getting last actual price at {forecast_input_end_time}: {e}. Skipping day.")
            continue

        # --- Store results for this day ---
        day_results = {
            'decision_day': decision_day,
            'target_day_start': target_day_start,
            'actual_prices': actual_prices_target_day.tolist()
        }

        # --- Baseline Optimization (Perfect Foresight for the Target Day) ---
        logger.debug(f"Running baseline optimization for target day {target_day_start.strftime('%Y-%m-%d')}")
        initial_b_state_baseline = current_b_states['baseline']
        try:
            # 1. Optimize using actual prices
            baseline_status, _, baseline_P_perfect, baseline_B_perfect = solve_battery_optimization_hourly(
                actual_prices_target_day, initial_b_state_baseline, optimization_config.max_capacity, optimization_config.max_rate
            )
            if baseline_status.lower() != "optimal":
                logger.warning(f"Baseline optimization non-optimal ({baseline_status}) for target day {target_day_start.strftime('%Y-%m-%d')}")

            # 2. Simulate the resulting schedule
            if baseline_P_perfect is not None:
                (daily_profit_baseline,
                 final_b_state_baseline,
                 executed_p_baseline,
                 actual_b_baseline) = simulate_daily_schedule(
                    baseline_P_perfect, actual_prices_target_day, initial_b_state_baseline,
                    optimization_config.max_capacity, optimization_config.max_rate)

                day_results['baseline'] = {
                    "status": baseline_status,
                    "daily_profit": daily_profit_baseline,
                    "planned_P_schedule": baseline_P_perfect.tolist(),
                    "executed_P_schedule": executed_p_baseline,
                    "actual_B_schedule_start": actual_b_baseline,  # State at the start of each hour
                    "final_B_state": final_b_state_baseline
                }
                current_b_states['baseline'] = final_b_state_baseline  # Update state for the next day
                logger.debug(f"Baseline daily profit: {daily_profit_baseline:.2f}, Final B: {final_b_state_baseline:.2f}")
            else:
                raise ValueError("Baseline optimization failed to produce a schedule.")

        except Exception as e:
            logger.error(f"Baseline simulation failed for target day {target_day_start.strftime('%Y-%m-%d')}: {e}", exc_info=True)
            day_results['baseline'] = {"status": "Error", "daily_profit": np.nan, "final_B_state": initial_b_state_baseline}
            # Keep the previous state if an error occurs
            current_b_states['baseline'] = initial_b_state_baseline


        # --- Forecast Provider Optimizations ---
        for provider_name in successfully_loaded_provider_names:
            provider_data = providers_data[provider_name]
            provider_instance = provider_data['provider']
            engineered_df_provider = provider_data['df']
            initial_b_state_provider = current_b_states[provider_name]

            logger.debug(f"Running DAA forecast and optimization for provider '{provider_name}'")

            # 1. Generate a 24h forecast for the target day
            try:
                forecast_prices_input = provider_instance.get_forecast(
                    engineered_df=engineered_df_provider,
                    forecast_start_time=forecast_input_end_time,  # t=0 for the forecast (end of Day D)
                    optimization_horizon_hours=OPTIMIZATION_HORIZON_HOURS,
                    last_actual_price=last_actual_price_anchor  # Price at t=0
                )
                if forecast_prices_input is None: raise ValueError("Provider returned None forecast.")
                if not isinstance(forecast_prices_input, np.ndarray) or forecast_prices_input.shape != (OPTIMIZATION_HORIZON_HOURS,):
                    raise ValueError(f"Forecast shape mismatch: Expected {(OPTIMIZATION_HORIZON_HOURS,)}, Got {forecast_prices_input.shape}")

            except Exception as e:
                logger.error(f"Forecast generation failed for '{provider_name}', target day {target_day_start.strftime('%Y-%m-%d')}: {e}", exc_info=True)
                day_results[provider_name] = {"status": "Forecast Error", "daily_profit": np.nan, "final_B_state": initial_b_state_provider}
                current_b_states[provider_name] = initial_b_state_provider  # Keep the previous state
                continue

            # 2. Run the optimization with forecast prices
            try:
                model_status, _, model_P_planned, model_B_planned = solve_battery_optimization_hourly(
                    forecast_prices_input, initial_b_state_provider, optimization_config.max_capacity, optimization_config.max_rate
                )
                if model_status.lower() != "optimal":
                    logger.warning(f"Provider '{provider_name}' optimization non-optimal ({model_status}) for target day {target_day_start.strftime('%Y-%m-%d')}")

                # 3. Simulate the planned schedule against actual prices
                if model_P_planned is not None:
                    (daily_profit_provider,
                     final_b_state_provider,
                     executed_p_provider,
                     actual_b_provider) = simulate_daily_schedule(
                        model_P_planned, actual_prices_target_day, initial_b_state_provider,
                        optimization_config.max_capacity, optimization_config.max_rate)

                    day_results[provider_name] = {
                        "status": model_status,
                        "daily_profit": daily_profit_provider,
                        "planned_P_schedule": model_P_planned.tolist(),
                        "executed_P_schedule": executed_p_provider,
                        "actual_B_schedule_start": actual_b_provider,
                        "final_B_state": final_b_state_provider,
                        "forecast_prices": forecast_prices_input.tolist()  # Store the forecast for analysis
                    }
                    current_b_states[provider_name] = final_b_state_provider  # Update state
                    logger.debug(f"Provider '{provider_name}' daily profit: {daily_profit_provider:.2f}, Final B: {final_b_state_provider:.2f}")
                else:
                    raise ValueError("Provider optimization failed to produce a schedule.")

            except Exception as e:
                logger.error(f"Optimization or simulation failed for '{provider_name}', target day {target_day_start.strftime('%Y-%m-%d')}: {e}", exc_info=True)
                day_results[provider_name] = {"status": "Optimization/Simulation Error", "daily_profit": np.nan, "final_B_state": initial_b_state_provider}
                current_b_states[provider_name] = initial_b_state_provider  # Keep the previous state

        # Append results for this day
        daily_results_list.append(day_results)
        logger.debug(f"Finished processing target day: {target_day_start.strftime('%Y-%m-%d')}")

    logger.info(f"Finished simulating {len(simulation_target_days)} DAA target days.")

    # --- Post-processing and Plotting ---
    logger.info("Starting DAA results analysis and plotting.")

    if not daily_results_list:
        logger.warning("No daily results were collected. Skipping analysis.")
        exit(0)

    # Convert the results list to a DataFrame
    flat_results = []
    for day_res in daily_results_list:
        base_info = {'target_day_start': day_res['target_day_start']}
        # Add baseline results
        flat_results.append({**base_info, 'type': 'baseline', **day_res.get('baseline', {})})
        # Add provider results
        for provider_name in successfully_loaded_provider_names:
            provider_res = day_res.get(provider_name, {"status": "Missing Result", "daily_profit": np.nan})
            flat_results.append({**base_info, 'type': provider_name, **provider_res})

    results_df = pd.DataFrame(flat_results)
    results_df['target_day_start'] = pd.to_datetime(results_df['target_day_start'])
    results_df.set_index('target_day_start', inplace=True)  # Index by the day being evaluated

    # Calculate cumulative profit
    profit_pivot = results_df.pivot_table(index=results_df.index, columns='type', values='daily_profit')
    cumulative_profit_df = profit_pivot.cumsum()

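    # Shape sketch (hypothetical numbers): profit_pivot is wide, one column per
    # strategy, indexed by target day, and cumsum() accumulates each column:
    #
    #   #                baseline  LSTM-Single-Model
    #   # 2017-03-03        120.4               95.1
    #   # 2017-03-04         80.2               70.9
    #   #
    #   # cumulative_profit_df last row -> 200.6, 166.0: the running totals plotted below.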
    # --- Plotting ---
    logger.info("Generating DAA plots.")
    output_dir = Path(optimization_config.output_dir) if optimization_config.output_dir else Path(".")
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving plots to: {output_dir.resolve()}")

    # Plot 1: Daily Profit Over Time
    if not profit_pivot.empty and not profit_pivot.isnull().all().all():
        fig1, ax = plt.subplots(figsize=(15, 7))
        plot_data = profit_pivot.dropna(axis=1, how='all')
        if not plot_data.empty:
            sns.lineplot(data=plot_data, ax=ax, dashes=False)  # Ensure solid lines for clarity
            ax.set_xlabel('Target Day')
            ax.set_ylabel('Daily Profit (€)')
            ax.set_title('DAA Strategy: Daily Profit Comparison')
            ax.legend(title='Strategy')
            ax.axhline(0, color='grey', linestyle='--', linewidth=0.8)  # Add a zero line
            plt.grid()
            plt.tight_layout()
            plot_path = output_dir / "daily_profit_daa.png"
            plt.savefig(plot_path)
            logger.info(f"Daily Profit plot saved to {plot_path}")
            plt.close(fig1)
        else:
            logger.warning("Daily profit data is all NaN after filtering. Skipping Daily Profit plot.")
    else:
        logger.warning("No valid data available to plot Daily Profit.")


    # Plot 2: Cumulative Profit Over Time
    if not cumulative_profit_df.empty and not cumulative_profit_df.isnull().all().all():
        fig2, ax = plt.subplots(figsize=(15, 7))
        plot_data = cumulative_profit_df.dropna(axis=1, how='all')
        if not plot_data.empty:
            sns.lineplot(data=plot_data, ax=ax, dashes=False)
            ax.set_xlabel('Target Day')
            ax.set_ylabel('Cumulative Profit (€)')
            ax.set_title('DAA Strategy: Cumulative Profit Comparison')
            ax.legend(title='Strategy')
            plt.grid()
            plt.tight_layout()
            plot_path = output_dir / "cumulative_profit_daa.png"
            plt.savefig(plot_path)
            logger.info(f"Cumulative Profit plot saved to {plot_path}")
            plt.close(fig2)
        else:
            logger.warning("Cumulative profit data is all NaN after filtering. Skipping Cumulative Profit plot.")
    else:
        logger.warning("No valid data available to plot Cumulative Profit.")

    # Optional Plot 3: Example Day Schedule (Planned vs Executed vs Actual Price)
    # Select a representative day (here simply the first valid day)
    example_day_index = results_df.index.unique()[0]  # Take the first day
    example_day_data = results_df.loc[[example_day_index]]  # Select rows for that day
    actual_prices_example = daily_results_list[0].get('actual_prices')  # Actual prices for the first day

    if actual_prices_example and not example_day_data.empty:
        fig3, ax1 = plt.subplots(figsize=(15, 7))
        ax2 = ax1.twinx()  # Secondary axis for prices

        hours = np.arange(OPTIMIZATION_HORIZON_HOURS)
        # plot_styles = {'baseline': 'r--', **{name: next(iter(ax1.lines))['color'] + '-' for name in successfully_loaded_provider_names}}  # Cycle colors

        for provider_type in ['baseline'] + successfully_loaded_provider_names:
            provider_row = example_day_data[example_day_data['type'] == provider_type]
            if not provider_row.empty:
                planned_p = provider_row['planned_P_schedule'].iloc[0]
                executed_p = provider_row['executed_P_schedule'].iloc[0]
                if isinstance(planned_p, list) and isinstance(executed_p, list) and len(planned_p) == 24 and len(executed_p) == 24:
                    # color, style = plot_styles[provider_type][:-1], plot_styles[provider_type][-1]
                    ax1.plot(hours, planned_p, label=f'{provider_type} Planned P', linestyle=':', marker='.')  # Dotted for planned
                    ax1.plot(hours, executed_p, label=f'{provider_type} Executed P', marker='x')  # Solid/dashed for executed

        # Plot actual prices
        ax2.plot(hours, actual_prices_example, label='Actual Price', color='grey', linestyle='-', marker='o', markersize=4)
        ax2.set_ylabel('Price (€/MWh)', color='grey')
        ax2.tick_params(axis='y', labelcolor='grey')
        # ax2.legend(loc='upper right')  # Maybe too crowded with the other legends

        ax1.set_xlabel('Hour of Day')
        ax1.set_ylabel('Power (MW)')
        ax1.set_title(f'DAA Strategy: Example Day Schedule ({example_day_index.strftime("%Y-%m-%d")})')
        ax1.axhline(0, color='black', linestyle='-', linewidth=0.5)  # Zero power line
        # Combine the legends of both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines + lines2, labels + labels2, loc='best')

        plt.xticks(hours)  # Show all hours
        plt.tight_layout()
        plot_path = output_dir / "example_schedule_daa.png"
        plt.grid()
        plt.savefig(plot_path)
        logger.info(f"Example Day Schedule plot saved to {plot_path}")
        plt.close(fig3)
    else:
        logger.warning("Could not generate example day schedule plot (data missing).")


    # --- Save Results DataFrame (Optional) ---
    try:
        results_save_path = output_dir / "optimization_results_daa.csv"
        results_df_to_save = results_df.reset_index()
        # Convert list columns to strings for CSV compatibility
        list_cols = ['planned_P_schedule', 'executed_P_schedule', 'actual_B_schedule_start', 'actual_prices', 'forecast_prices']
        for col in list_cols:
            if col in results_df_to_save.columns:
                results_df_to_save[col] = results_df_to_save[col].apply(lambda x: str(x) if isinstance(x, list) else x)

        results_df_to_save.to_csv(results_save_path, index=False)
        logger.info(f"Saved detailed DAA results DataFrame to {results_save_path}")
    except Exception as e:
        logger.error(f"Failed to save DAA results DataFrame: {e}", exc_info=True)


    logger.info("DAA Evaluation and plotting completed.")
@ -1,22 +1,48 @@
# forecasting/base.py
from typing import List, Dict, Any
from typing import List
import pandas as pd
import numpy as np
import logging

logger = logging.getLogger(__name__)


class ForecastProvider:
    def get_forecasts(self,
                      historical_data: pd.DataFrame,
                      forecast_horizons: List[int],
                      optimization_horizon: int) -> Dict[int, np.ndarray]:
        """Returns forecasts for each requested horizon."""
        pass

    def get_forecast(
        self,
        engineered_df: pd.DataFrame,
        forecast_start_time: pd.Timestamp,
        optimization_horizon_hours: int,
        last_actual_price: float
    ) -> np.ndarray | None:
        """
        Generates an hourly forecast for the specified optimization horizon.

        Args:
            engineered_df: DataFrame containing pre-calculated features **specifically
                engineered for this provider**. Must have a DatetimeIndex.
            forecast_start_time: The timestamp corresponding to the last data point
                available before the forecast begins (t=0). The input
                sequence for the model will end at this timestamp.
            optimization_horizon_hours: The number of future hours to forecast.
            last_actual_price: The actual observed value at forecast_start_time, used as the
                anchor point (t=0) for interpolation.

        Returns:
            A numpy array of shape (optimization_horizon_hours,) containing the hourly forecast,
            or None if forecasting fails.
        """
        raise NotImplementedError("Subclasses must implement get_forecast")

    def get_required_lookback(self) -> int:
        """Returns the minimum number of historical data points required."""
        pass
        """
        Returns the minimum number of *sequence* data points required as input by the model(s).
        This corresponds to the model's sequence_length (or max sequence_length for ensembles).
        Feature engineering lookback is handled *before* calling get_forecast, during the
        creation of the provider-specific engineered_df.
        """
        raise NotImplementedError("Subclasses must implement get_required_lookback")

    def get_forecast_horizons(self) -> List[int]:
        """Returns the list of forecast horizons."""
        pass
        """Returns the list of native forecast horizons the model predicts."""
        logger.warning(f"{self.__class__.__name__} does not explicitly provide forecast horizons. Interpolation might rely on defaults or fail if horizons are needed.")
        return []
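# --- Illustrative sketch (not part of the commit): a minimal concrete provider ---
# Shows the get_forecast contract defined above with a naive persistence forecast.
# The class name and the flat-forecast logic are hypothetical, for illustration only.
class NaivePersistenceProvider(ForecastProvider):
    def get_required_lookback(self) -> int:
        return 1  # Only the anchor point is needed

    def get_forecast(
        self,
        engineered_df: pd.DataFrame,
        forecast_start_time: pd.Timestamp,
        optimization_horizon_hours: int,
        last_actual_price: float
    ) -> np.ndarray | None:
        if forecast_start_time not in engineered_df.index:
            return None
        # Persist the last observed value over the whole horizon
        return np.full(optimization_horizon_hours, last_actual_price)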
@ -1,135 +1,164 @@
import logging
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.base import ForecastProvider
from optimizer.forecasting.utils import interpolate_forecast


logger = logging.getLogger(__name__)


class EnsembleProvider(ForecastProvider):
    """Provides forecasts using an ensemble of trained LSTM models."""
    """
    Provides forecasts using an ensemble of trained LSTM models.
    """

    def __init__(
        self,
        fold_artifacts: List[Dict[str, Any]],
        ensemble_method: str,
        ensemble_feature_config: FeatureConfig,  # Assumed consistent across folds by loading logic
        ensemble_target_col: str,  # Assumed consistent
    ):
        if not fold_artifacts:
            raise ValueError("EnsembleProvider requires at least one fold artifact.")

        self.fold_artifacts = fold_artifacts
        self.ensemble_method = ensemble_method
        # Store common config for reference, but use fold-specific details in get_forecast
        self.ensemble_feature_config = ensemble_feature_config
        self.ensemble_target_col = ensemble_target_col
        self.common_forecast_horizons = sorted(ensemble_feature_config.forecast_horizon)  # Assumed consistent

        # Calculate max lookback needed across all folds
        max_lookback = 0
        for i, fold in enumerate(fold_artifacts):
            try:
                fold_feature_config = fold['feature_config']
                fold_seq_len = fold_feature_config.sequence_length

                feature_lookback = 0
                if fold_feature_config.lags:
                    feature_lookback = max(feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes:
                    feature_lookback = max(feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))

                fold_total_lookback = fold_seq_len + feature_lookback
                max_lookback = max(max_lookback, fold_total_lookback)
            except KeyError as e:
                raise ValueError(f"Fold artifact {i} is missing expected key: {e}") from e
            except Exception as e:
                raise ValueError(f"Error processing fold artifact {i} for lookback calculation: {e}") from e

        self._required_lookback = max_lookback
        logger.debug(f"EnsembleProvider initialized with {len(fold_artifacts)} folds. Method: '{ensemble_method}'. Required lookback: {self._required_lookback}")
        self._max_sequence_length = 0
        self._fold_info: Dict[Any, Dict[str, Any]] = {}  # Store simplified fold info by fold_id

        if ensemble_method not in ['mean', 'median']:
            raise ValueError(f"Unsupported ensemble method: {ensemble_method}. Use 'mean' or 'median'.")

        # Process folds to store necessary info (model, scaler, seq_len, horizons) and find max sequence length
        for i, fold_data in enumerate(fold_artifacts):
            fold_id = fold_data.get("fold_id", i + 1)  # Use provided fold_id or index
            try:
                model = fold_data['model_instance']
                f_config = fold_data['feature_config']
                target_scaler = fold_data['target_scaler']
                data_scaler = fold_data['data_scaler']
                seq_len = f_config.sequence_length
                horizons = sorted(f_config.forecast_horizon)
                model_input_dim = getattr(model, 'input_size', None)  # Store optional input dim

                self._max_sequence_length = max(self._max_sequence_length, seq_len)
                self._fold_info[fold_id] = {
                    'model': model,
                    'target_scaler': target_scaler,
                    'data_scaler': data_scaler,
                    'sequence_length': seq_len,
                    'horizons': horizons,
                    'model_input_dim': model_input_dim
                }

            except KeyError as e:
                raise ValueError(f"Fold artifact (id: {fold_id}) is missing expected key: {e}") from e
            except Exception as e:
                raise ValueError(f"Error processing fold artifact (id: {fold_id}) for setup: {e}") from e

        self._required_lookback = self._max_sequence_length
        logger.debug(f"EnsembleProvider initialized with {len(self._fold_info)} folds. Method: '{ensemble_method}'. Max sequence length: {self._required_lookback}. Using shared data scaler.")


    def get_required_lookback(self) -> int:
        """Returns the maximum sequence length required across all fold models."""
        return self._required_lookback

    def get_forecast_horizons(self) -> List[int]:
        """Returns the list of forecast horizons for the first fold (representative)."""
        if self._fold_info:
            first_fold_id = next(iter(self._fold_info))  # Get the key of the first item
            return self._fold_info[first_fold_id].get('horizons', [])
        return []


    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int
        engineered_df: pd.DataFrame,  # Receives the DF engineered for *this specific ensemble*
        forecast_start_time: pd.Timestamp,  # Use timestamp
        optimization_horizon_hours: int,
        last_actual_price: float
    ) -> np.ndarray | None:
        """
        Generates forecasts from each fold model, interpolates, and aggregates.
        Generates forecasts from each fold model using the ensemble's dedicated pre-engineered features,
        applies the shared data scaler, interpolates, and aggregates based on timestamp.
        """
        logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours using {self.ensemble_method}.")
        if len(historical_data_slice) < self._required_lookback:
            logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
            return None
        logger.debug(f"EnsembleProvider: Generating forecast for {optimization_horizon_hours} hours starting after {forecast_start_time} using {self.ensemble_method}.")

        fold_forecasts_interpolated = []
        last_actual_price = historical_data_slice[self.ensemble_target_col].iloc[-1]  # Common anchor for all folds
        original_columns = engineered_df.columns  # Keep original columns for scaled DataFrame

        for fold_id, fold_data in self._fold_info.items():
            fold_model = fold_data['model']
            fold_target_scaler = fold_data['target_scaler']
            fold_data_scaler = fold_data['data_scaler']
            fold_seq_len = fold_data['sequence_length']
            fold_horizons = fold_data['horizons']
            fold_model_input_dim = fold_data['model_input_dim']

        for i, fold_artifact in enumerate(self.fold_artifacts):
            fold_id = fold_artifact.get("fold_id", i + 1)
            try:
                fold_model: LSTMForecastLightningModule = fold_artifact['model_instance']
                fold_feature_config: FeatureConfig = fold_artifact['feature_config']
                fold_target_scaler: Optional[Any] = fold_artifact['target_scaler']
                fold_target_col: str = fold_artifact['main_forecasting_config'].data.target_col  # Use fold specific target
                fold_seq_len = fold_feature_config.sequence_length
                fold_horizons = sorted(fold_feature_config.forecast_horizon)
                # 1. Select Input Sequence (Rows only, using timestamp)
                # Find the integer location of the forecast_start_time
                end_loc = engineered_df.index.get_loc(forecast_start_time)
                seq_start_loc = end_loc - fold_seq_len + 1

                # Calculate lookback needed *for this specific fold* to check slice length
                fold_feature_lookback = 0
                if fold_feature_config.lags: fold_feature_lookback = max(fold_feature_lookback, max(fold_feature_config.lags))
                if fold_feature_config.rolling_window_sizes: fold_feature_lookback = max(fold_feature_lookback, max(w - 1 for w in fold_feature_config.rolling_window_sizes))
                fold_total_lookback = fold_seq_len + fold_feature_lookback

                if len(historical_data_slice) < fold_total_lookback:
                    logger.warning(f"Fold {fold_id}: Skipping fold. Insufficient historical data in slice for this fold's lookback ({fold_total_lookback} needed).")
                if seq_start_loc < 0:
                    logger.warning(f"Fold {fold_id}: Skipping fold. Calculated sequence start location ({seq_start_loc}) is negative for time {forecast_start_time}.")
                    continue

                # 1. Feature Engineering (using fold's config)
                # Slice needs to be long enough for this fold's total lookback.
                # The input slice `historical_data_slice` should already be long enough based on max_lookback.
                engineered_df_fold = engineer_features(historical_data_slice.copy(), fold_target_col, fold_feature_config)
                # Slice using iloc derived from the timestamp
                input_sequence_data_fold = engineered_df.iloc[seq_start_loc : end_loc + 1]

                if engineered_df_fold.isnull().any().any():
                    logger.warning(f"Fold {fold_id}: NaNs found after feature engineering. Attempting fill.")
                    engineered_df_fold = engineered_df_fold.ffill().bfill()
                    if engineered_df_fold.isnull().any().any():
                        logger.error(f"Fold {fold_id}: NaNs persist after fill. Skipping fold.")
                # Check for NaNs *before* scaling
                if input_sequence_data_fold.isnull().values.any():
                    logger.error(f"Fold {fold_id}: NaNs found in the input sequence slice ending at {forecast_start_time} (iloc {seq_start_loc}:{end_loc+1}) *before* scaling. Skipping fold.")
                    nan_cols = input_sequence_data_fold.columns[input_sequence_data_fold.isnull().any()].tolist()
                    logger.error(f"Fold {fold_id}: Columns with NaNs in sequence: {nan_cols}")
                    continue


                if fold_data_scaler:  # Check if a data scaler exists for this fold
                    try:
                        scaled_sequence_data_np = fold_data_scaler.transform(input_sequence_data_fold.values)
                    except Exception as e:
                        logger.error(f"Fold {fold_id}: Error applying shared data_scaler: {e}. Skipping fold.", exc_info=True)
                        scaled_sequence_data_np = input_sequence_data_fold.values
                        continue

                # 2. Create *one* input sequence (using fold's sequence length)
                if len(engineered_df_fold) < fold_seq_len:
                    logger.error(f"Fold {fold_id}: Engineered data ({len(engineered_df_fold)}) is shorter than fold sequence length ({fold_seq_len}). Skipping fold.")
                # Check for NaNs *after* scaling (might occur with certain scalers or data issues)
                if np.isnan(scaled_sequence_data_np).any():
                    logger.error(f"Fold {fold_id}: NaNs found in the input sequence *after* applying data_scaler. Skipping fold.")
                    continue

                input_sequence_data_fold = engineered_df_fold.iloc[-fold_seq_len:].copy()
                feature_columns_fold = [col for col in engineered_df_fold.columns if col != fold_target_col]  # Example
                if not feature_columns_fold: feature_columns_fold = engineered_df_fold.columns.tolist()
                input_sequence_np_fold = input_sequence_data_fold[feature_columns_fold].values

                if input_sequence_np_fold.shape != (fold_seq_len, len(feature_columns_fold)):
                    logger.error(f"Fold {fold_id}: Input sequence has wrong shape. Expected ({fold_seq_len}, {len(feature_columns_fold)}), got {input_sequence_np_fold.shape}. Skipping fold.")
                # Shape check (using the scaled NumPy array)
                num_features_in_df = len(original_columns)  # Get feature count from original columns
                if scaled_sequence_data_np.shape != (fold_seq_len, num_features_in_df):
                    logger.error(f"Fold {fold_id}: Scaled input sequence NumPy array has wrong shape. Expected ({fold_seq_len}, {num_features_in_df}), got {scaled_sequence_data_np.shape}. Skipping fold.")
                    continue

                input_tensor_fold = torch.FloatTensor(input_sequence_np_fold).unsqueeze(0)
                # Optional: Check fold model's input size if available
                if fold_model_input_dim is not None and fold_model_input_dim != num_features_in_df:
                    logger.error(f"Fold {fold_id}: Model expected input size ({fold_model_input_dim}) mismatch with features in DataFrame ({num_features_in_df}). Skipping fold.")
                    continue

                # 3. Run Inference (using fold's model)
                input_tensor_fold = torch.FloatTensor(scaled_sequence_data_np).unsqueeze(0)  # Use scaled data

            except KeyError:
                logger.warning(f"Fold {fold_id}: Timestamp {forecast_start_time} not found in the ensemble's engineered DataFrame index. Skipping fold.")
                continue
            except IndexError as e:
                logger.error(f"Fold {fold_id}: Error slicing ensemble DataFrame (shape {engineered_df.shape}) using iloc {seq_start_loc}:{end_loc+1} derived from time {forecast_start_time}: {e}. Skipping fold.", exc_info=True)
                continue
            except Exception as e:
                logger.error(f"Fold {fold_id}: Error preparing input sequence ending at {forecast_start_time}: {e}. Skipping fold.", exc_info=True)
                continue

            # 2. Run Inference (using fold's model and scaled input)
            try:
                fold_model.eval()
                with torch.no_grad():
                    predictions_scaled_fold = fold_model(input_tensor_fold)  # Shape (1, num_fold_horizons)
@ -139,50 +168,61 @@ class EnsembleProvider(ForecastProvider):
                        continue

                predictions_scaled_np_fold = predictions_scaled_fold.squeeze(0).cpu().numpy()
            except Exception as e:
                logger.error(f"Fold {fold_id}: Error during model inference: {e}. Skipping fold.", exc_info=True)
                continue

            # 4. Inverse Transform (using fold's scaler)
            predictions_original_scale_fold = predictions_scaled_np_fold
            if fold_target_scaler:
                try:
                    predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
                except Exception as e:
                    logger.error(f"Fold {fold_id}: Failed to apply inverse transform: {e}. Skipping fold.", exc_info=True)
                    continue

            # 5. Interpolate (using fold's horizons)
            # 3. Inverse Transform (using fold's *target* scaler)
            predictions_original_scale_fold = predictions_scaled_np_fold
            if fold_target_scaler:  # Check if a target scaler exists for this fold
                try:
                    # Ensure the prediction array is 2D for inverse_transform
                    predictions_original_scale_fold = fold_target_scaler.inverse_transform(predictions_scaled_np_fold.reshape(-1, 1)).flatten()
                except Exception as e:
                    logger.error(f"Fold {fold_id}: Failed to apply inverse *target* transform: {e}. Skipping fold.", exc_info=True)
                    continue


            # 4. Interpolate (using fold's horizons and the common last_actual_price)
            try:
                interpolated_forecast_fold = interpolate_forecast(
                    native_horizons=fold_horizons,
                    native_predictions=predictions_original_scale_fold,
                    target_horizon=optimization_horizon_hours,
                    last_known_actual=last_actual_price
                    last_known_actual=last_actual_price  # Use common anchor
                )

                if interpolated_forecast_fold is not None:
                    fold_forecasts_interpolated.append(interpolated_forecast_fold)
                    logger.debug(f"Fold {fold_id}: Successfully generated interpolated forecast.")
                else:
                    # interpolate_forecast logs errors internally
                    logger.warning(f"Fold {fold_id}: Interpolation failed. Skipping fold.")

            except Exception as e:
                logger.error(f"Error processing ensemble fold {fold_id}: {e}", exc_info=True)
                continue  # Skip this fold on error
                logger.error(f"Fold {fold_id}: Error during interpolation: {e}. Skipping fold.", exc_info=True)
                continue

        # --- Aggregation ---
        if not fold_forecasts_interpolated:
            logger.error("No successful forecasts generated from any ensemble folds.")
            logger.error(f"No successful forecasts generated from any ensemble folds for window after {forecast_start_time}.")
            return None

        logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} folds using '{self.ensemble_method}'.")
        stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0)  # Shape (n_folds, target_horizon)
        logger.debug(f"Aggregating forecasts from {len(fold_forecasts_interpolated)} successful folds using '{self.ensemble_method}'.")
        stacked_predictions = np.stack(fold_forecasts_interpolated, axis=0)  # Shape (n_successful_folds, target_horizon)

        if self.ensemble_method == 'mean':
            final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
        elif self.ensemble_method == 'median':
            final_ensemble_forecast = np.median(stacked_predictions, axis=0)
        else:
            # Should be caught in __init__, but double-check
            logger.error(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
            return None
        try:
            if self.ensemble_method == 'mean':
                final_ensemble_forecast = np.mean(stacked_predictions, axis=0)
            elif self.ensemble_method == 'median':
                final_ensemble_forecast = np.median(stacked_predictions, axis=0)
            else:
                # Should be caught in __init__, but double-check
                logger.critical(f"Internal error: Invalid ensemble method '{self.ensemble_method}' during aggregation.")
                return None  # Or raise error

            logger.debug(f"EnsembleProvider: Successfully generated forecast.")
            return final_ensemble_forecast
            logger.debug(f"EnsembleProvider: Successfully generated forecast for window after {forecast_start_time}.")
            return final_ensemble_forecast
        except Exception as e:
            logger.error(f"Error during final ensemble aggregation: {e}", exc_info=True)
            return None
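# --- Illustrative sketch (not part of the commit): mean vs. median aggregation ---
# Toy numbers showing why the median is more robust to a single bad fold forecast.
import numpy as np

fold_forecasts = [
    np.array([50.0, 52.0, 55.0]),
    np.array([51.0, 53.0, 54.0]),
    np.array([500.0, 530.0, 550.0]),  # hypothetical outlier fold
]
stacked = np.stack(fold_forecasts, axis=0)   # shape (n_folds, horizon)
print(np.mean(stacked, axis=0))              # dragged up by the outlier
print(np.median(stacked, axis=0))            # stays near the two sane folds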
@ -1,110 +1,137 @@
import logging
from typing import List, Dict, Any, Optional
from typing import List, Any, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Imports from our project structure
from .base import ForecastProvider
from forecasting_model.utils import FeatureConfig
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model import engineer_features
from optimizer.forecasting.utils import interpolate_forecast

logger = logging.getLogger(__name__)


class SingleModelProvider(ForecastProvider):
    """Provides forecasts using a single trained LSTM model."""
    """Provides forecasts using a single trained LSTM model and its dedicated pre-engineered DataFrame."""

    def __init__(
        self,
        model_instance: LSTMForecastLightningModule,
        feature_config: FeatureConfig,
        target_col: str,
        target_scaler: Optional[Any],  # BaseEstimator, TransformerMixin -> more specific if possible
        # input_size: int # Not needed directly if model instance is configured
        feature_config: FeatureConfig,  # Needed for sequence_length, forecast_horizon
        target_scaler: Optional[Any],  # Pass the scaler fitted during training/loading
        data_scaler: Optional[Any],  # Pass the scaler used for input features
    ):
        self.model = model_instance
        self.feature_config = feature_config
        self.target_col = target_col
        self.target_scaler = target_scaler
        # Store necessary parts of feature_config
        self.sequence_length = feature_config.sequence_length
        self.forecast_horizons = sorted(feature_config.forecast_horizon)  # Ensure sorted
        self.target_scaler = target_scaler  # Store the target scaler
        self.data_scaler = data_scaler  # Store the input data scaler

        # Calculate required lookback for feature engineering
        feature_lookback = 0
        if feature_config.lags:
            feature_lookback = max(feature_lookback, max(feature_config.lags))
        if feature_config.rolling_window_sizes:
            # Rolling window of size W needs W-1 previous points
            feature_lookback = max(feature_lookback, max(w - 1 for w in feature_config.rolling_window_sizes))
        # Input size check (optional, depends on model having attribute)
        self._model_input_dim = getattr(model_instance, 'input_size', None)
        if self._model_input_dim is None:
            logger.debug("Could not get model's expected input size attribute.")

        # Total lookback: sequence length for model input + feature engineering needs
        # We need `sequence_length` points for the *last* input sequence.
        # The first point of that sequence needs `feature_lookback` points before it.
        # So, total points needed before the *end* of the input sequence is sequence_length + feature_lookback.
        # Since the input sequence ends *before* the first forecast point (t=1),
        # we need `sequence_length + feature_lookback` points before t=1.
        self._required_lookback = self.sequence_length + feature_lookback
        logger.debug(f"SingleModelProvider initialized. Required lookback: {self._required_lookback} (SeqLen: {self.sequence_length}, FeatLookback: {feature_lookback})")
        logger.debug(f"SingleModelProvider initialized. Sequence length: {self.sequence_length}, Horizons: {self.forecast_horizons}. Target scaler: {'Provided' if self.target_scaler else 'None'}, Data scaler: {'Provided' if self.data_scaler else 'None'}.")


    def get_required_lookback(self) -> int:
        return self._required_lookback
        """Returns the sequence length required by the model."""
        return self.sequence_length

    def get_forecast_horizons(self) -> List[int]:
        """Returns the list of forecast horizons the model natively predicts."""
        return self.forecast_horizons

    def get_forecast(
        self,
        historical_data_slice: pd.DataFrame,
        optimization_horizon_hours: int
        engineered_df: pd.DataFrame,  # Receives the DF engineered specifically for this model
        forecast_start_time: pd.Timestamp,  # Use timestamp instead of iloc
        optimization_horizon_hours: int,
        last_actual_price: float
    ) -> np.ndarray | None:
        """
        Generates forecast using the single model and interpolates to hourly resolution.
        Generates forecast using the single model and its dedicated pre-engineered features DataFrame.
        Selects input sequence based on forecast_start_time.
        """
        logger.debug(f"SingleModelProvider: Generating forecast for {optimization_horizon_hours} hours.")
        if len(historical_data_slice) < self._required_lookback:
            logger.error(f"Insufficient historical data provided. Need {self._required_lookback}, got {len(historical_data_slice)}.")
            return None
        logger.debug(f"SingleModelProvider: Generating forecast for {optimization_horizon_hours} hours starting after {forecast_start_time}.")

        # 1. Select Input Sequence using Timestamp
        try:
            # 1. Feature Engineering
            # Use the provided slice which already includes the lookback.
            engineered_df = engineer_features(historical_data_slice.copy(), self.target_col, self.feature_config)
            # Find the integer location of the forecast_start_time in the DataFrame's index
            end_loc = engineered_df.index.get_loc(forecast_start_time)
            start_loc = end_loc - self.sequence_length + 1

            # Check for NaNs after feature engineering before creating sequences
            if engineered_df.isnull().any().any():
                logger.warning("NaNs found after feature engineering. Attempting to fill with ffill/bfill.")
                # Be careful about filling target vs features if needed
                engineered_df = engineered_df.ffill().bfill()
                if engineered_df.isnull().any().any():
                    logger.error("NaNs persist after fill. Cannot create sequences.")
                    return None

            # 2. Create *one* input sequence ending at the last point of the historical slice
            # This sequence is used to predict starting from the next hour (t=1)
            if len(engineered_df) < self.sequence_length:
                logger.error(f"Engineered data ({len(engineered_df)}) is shorter than sequence length ({self.sequence_length}).")
            if start_loc < 0:
                logger.error(f"Cannot create input sequence: calculated start location ({start_loc}) is negative. forecast_start_time={forecast_start_time}, sequence_length={self.sequence_length}.")
                return None

            input_sequence_data = engineered_df.iloc[-self.sequence_length:].copy()
            # Slice using iloc derived from the timestamp
            input_sequence_data = engineered_df.iloc[start_loc : end_loc + 1]

            # Convert sequence data to numpy array (excluding target if model expects it that way)
            # Assuming model takes all engineered features as input
            # TODO: Verify the exact features the model expects (target included/excluded?)
            # Assuming all columns except maybe the original target are features
            feature_columns = [col for col in engineered_df.columns if col != self.target_col]  # Example
            if not feature_columns: feature_columns = engineered_df.columns.tolist()  # Use all if target wasn't dropped
            input_sequence_np = input_sequence_data[feature_columns].values
            # --- Validation Checks ---
            if len(input_sequence_data) != self.sequence_length:
                logger.error(f"Selected input sequence has wrong length. Expected {self.sequence_length}, got {len(input_sequence_data)}. Check slicing logic around {forecast_start_time} (iloc {start_loc}:{end_loc+1}).")
                return None

            if input_sequence_np.shape != (self.sequence_length, len(feature_columns)):
                logger.error(f"Input sequence has wrong shape. Expected ({self.sequence_length}, {len(feature_columns)}), got {input_sequence_np.shape}")
            # Check for NaNs *within the selected sequence*
            if input_sequence_data.isnull().values.any():  # Faster check on numpy array
                logger.error(f"NaNs found in the input sequence ending at {forecast_start_time} (iloc {start_loc} to {end_loc}). Cannot generate forecast.")
                # Find problematic columns for debugging
                nan_cols = input_sequence_data.columns[input_sequence_data.isnull().any()].tolist()
                logger.error(f"Columns with NaNs in sequence: {nan_cols}")
                return None

            input_tensor = torch.FloatTensor(input_sequence_np).unsqueeze(0)  # Add batch dim
            input_sequence_np = input_sequence_data.astype('float64').values
            n_features_in_df = input_sequence_np.shape[1]  # Get actual feature count from data

            # 3. Run Inference
            # --- Scale Input Features (if data_scaler is provided) ---
            if self.data_scaler:
                try:
                    # Scaler expects shape (n_samples, n_features)
                    input_sequence_scaled_np = self.data_scaler.transform(input_sequence_np)
                    logger.debug("Applied data scaler transform to input sequence.")
                except Exception as e:
                    input_sequence_scaled_np = None
                    logger.error(f"Failed to apply data scaler transform: {e}", exc_info=True)
                    return None  # Fail if scaling doesn't work
            else:
                input_sequence_scaled_np = input_sequence_np  # Use unscaled if no scaler
                logger.debug("No data scaler provided, using unscaled input sequence.")


            # Shape check (on scaled data if applicable)
            if input_sequence_scaled_np.shape != (self.sequence_length, n_features_in_df):
                # This check is still relevant, even after scaling
                logger.error(f"Scaled input sequence NumPy array has wrong shape. Expected ({self.sequence_length}, {n_features_in_df}), got {input_sequence_scaled_np.shape}")
                return None

            # Check against model's expected input size if available
            if self._model_input_dim is not None and self._model_input_dim != n_features_in_df:
                logger.error(f"Model's expected input size ({self._model_input_dim}) does not match number of features in provided engineered data ({n_features_in_df}).")
                return None


            # Convert to tensor
            input_tensor = torch.FloatTensor(input_sequence_scaled_np).unsqueeze(0)  # Add batch dim

        except KeyError:
            logger.error(f"Timestamp {forecast_start_time} not found in the engineered DataFrame's index. Cannot select input sequence.")
            return None
        except IndexError as e:
            # Might occur if get_loc works but slicing fails (less likely with check above)
            logger.error(f"Error slicing engineered DataFrame (shape {engineered_df.shape}) using iloc {start_loc}:{end_loc+1} derived from time {forecast_start_time}: {e}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Error preparing input sequence ending at {forecast_start_time}: {e}", exc_info=True)
            return None


        # 2. Run Inference
        try:
            self.model.eval()
            with torch.no_grad():
                # Model output shape: (1, num_horizons)
@ -115,36 +142,39 @@ class SingleModelProvider(ForecastProvider):
                return None

            predictions_scaled_np = predictions_scaled.squeeze(0).cpu().numpy()  # Shape: (num_horizons,)
        except Exception as e:
            logger.error(f"Error during model inference: {e}", exc_info=True)
            return None

        # 4. Inverse Transform
        predictions_original_scale = predictions_scaled_np
        if self.target_scaler:
            try:
                # Scaler expects shape (n_samples, n_features), even if n_features=1
                predictions_original_scale = self.target_scaler.inverse_transform(predictions_scaled_np.reshape(-1, 1)).flatten()
                logger.debug("Applied inverse transform to predictions.")
            except Exception as e:
                logger.error(f"Failed to apply inverse transform: {e}", exc_info=True)
                # Decide whether to return scaled or None. Returning None is safer.
                return None
        # 3. Inverse Transform
        predictions_original_scale = predictions_scaled_np
        if self.target_scaler:
            try:
                # Scaler expects shape (n_samples, n_features), even if n_features=1
                predictions_original_scale = self.target_scaler.inverse_transform(predictions_scaled_np.reshape(-1, 1)).flatten()
                logger.debug("Applied inverse transform to predictions.")
            except Exception as e:
                logger.error(f"Failed to apply inverse transform: {e}", exc_info=True)
                return None  # Fail if inverse transform doesn't work

        # 5. Interpolate
        # Use the last actual price from the input data as the anchor point t=0
        last_actual_price = historical_data_slice[self.target_col].iloc[-1]

        # 4. Interpolate
        try:
            interpolated_forecast = interpolate_forecast(
                native_horizons=self.forecast_horizons,
                native_predictions=predictions_original_scale,
                target_horizon=optimization_horizon_hours,
                last_known_actual=last_actual_price
                last_known_actual=last_actual_price  # Use the provided value
            )

            if interpolated_forecast is None:
                # interpolate_forecast logs errors internally
                logger.error("Interpolation step failed.")
                return None

            logger.debug(f"SingleModelProvider: Successfully generated forecast.")
            logger.debug(f"SingleModelProvider: Successfully generated forecast for window after {forecast_start_time}.")
            return interpolated_forecast

        except Exception as e:
            logger.error(f"Error during single model forecast generation: {e}", exc_info=True)
            logger.error(f"Error during forecast interpolation: {e}", exc_info=True)
            return None
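# --- Illustrative sketch (not part of the commit): timestamp-based window selection ---
# The pattern used in get_forecast above: get_loc maps the timestamp to an integer
# position, then iloc takes the sequence_length rows that end at that timestamp.
import numpy as np
import pandas as pd

idx = pd.date_range("2024-01-01", periods=10, freq="h")
df = pd.DataFrame({"price": np.arange(10.0)}, index=idx)

sequence_length = 4
forecast_start_time = pd.Timestamp("2024-01-01 07:00")

end_loc = df.index.get_loc(forecast_start_time)   # -> 7
start_loc = end_loc - sequence_length + 1         # -> 4
window = df.iloc[start_loc : end_loc + 1]         # rows 04:00..07:00 inclusive
assert len(window) == sequence_length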
@ -1,4 +1,4 @@
from typing import List, Optional, Dict, Any
from typing import List, Optional

import numpy as np
import logging
@ -12,7 +12,7 @@ def interpolate_forecast(
    native_horizons: List[int],
    native_predictions: np.ndarray,
    target_horizon: int,
    last_known_actual: Optional[float] = None  # Optional: use last known price as t=0 for anchor
    last_known_actual: Optional[float] = None
) -> np.ndarray | None:
    """
    Linearly interpolates model predictions at native horizons to a full hourly sequence.
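# --- Illustrative sketch (not part of the commit): anchor-based linear interpolation ---
# What interpolate_forecast is documented to do, expressed with np.interp: predictions
# at native horizons (e.g. t=1, 6, 12, 24) are interpolated to every hour, optionally
# anchored at t=0 with the last known actual value. Numbers are hypothetical.
import numpy as np

native_horizons = [1, 6, 12, 24]
native_predictions = np.array([52.0, 48.0, 55.0, 60.0])
last_known_actual = 50.0
target_horizon = 24

xs = [0] + native_horizons                       # prepend the t=0 anchor
ys = [last_known_actual] + list(native_predictions)
hourly = np.interp(np.arange(1, target_horizon + 1), xs, ys)
# hourly has shape (24,): values at t=1..24, linear between native horizons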
@ -1,5 +1,8 @@
import cvxpy as cp
import numpy as np
import logging

logger = logging.getLogger(__name__)


def solve_battery_optimization_hourly(
@ -19,7 +22,7 @@ def solve_battery_optimization_hourly(
        max_rate: Maximum charge/discharge power rate (MW).

    Returns:
        Tuple: (status, optimal_profit, power_schedule, B_schedule)
        Tuple: (status, optimal_profit, P_schedule, B_schedule)
        Returns (status, None, None, None) if optimization fails.
    """
    n_hours = len(hourly_prices)
@ -33,20 +36,25 @@ def solve_battery_optimization_hourly(
    # --- Objective Function ---
    # Profit = sum(price[t] * Power[t])
    prices = np.array(hourly_prices)
    profit = prices @ P  # Equivalent to cp.sum(cp.multiply(prices, P)) / prices.dot(P)
    # P is -discharge/+charge, so we negate to get profit
    # (selling = discharge = -P, so profit = price * -P)
    profit = -prices @ P
    objective = cp.Maximize(profit)

    # --- Constraints ---
    constraints = []
    constraints = list()

    # 1. Initial B
    constraints.append(B[0] == initial_B)

    # 2. B Dynamics: B[t+1] = B[t] - P[t] * 1 hour
    # 2. B Dynamics: B[t+1] = B[t] + P[t] * 1 hour
    # P is -discharge/+charge, so battery increases when charging (P > 0)
    # and decreases when discharging (P < 0)
    constraints.append(B[1:] == B[:-1] + P)

    # 3. Power Rate Limits: -max_rate <= P[t] <= max_rate
    constraints.append(cp.abs(P) <= max_rate)
    constraints.append(P <= max_rate)
    constraints.append(P >= -max_rate)

    # 4. B Limits: 0 <= B[t] <= max_capacity (applies to B[0]...B[n])
    constraints.append(B >= 0)
@ -66,9 +74,9 @@ def solve_battery_optimization_hourly(
                B.value  # NumPy array of optimal B at start of each hour
            )
        else:
            print(f"Optimization failed. Solver status: {problem.status}")
            logger.error(f"Optimization failed. Solver status: {problem.status}")
            return problem.status, None, None, None

    except cp.error.SolverError as e:
        print(f"Solver Error: {e}")
        logger.error(f"Solver Error: {e}")
        return "Solver Error", None, None, None
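# --- Illustrative sketch (not part of the commit): calling the solver ---
# Hedged usage example with toy prices; the sign convention matches the constraints
# above (P > 0 charges the battery, P < 0 discharges/sells). Keyword argument names
# are taken from the function body and docstring in this diff.
import numpy as np

toy_prices = [30.0, 20.0, 10.0, 40.0, 80.0, 60.0]  # €/MWh, hypothetical
status, profit, P_schedule, B_schedule = solve_battery_optimization_hourly(
    hourly_prices=toy_prices,
    initial_B=0.0,      # start empty (MWh)
    max_capacity=2.0,   # MWh
    max_rate=1.0,       # MW
)
if status == "optimal":  # cvxpy reports solved problems as "optimal"
    # Expect charging (P > 0) in cheap hours and discharging (P < 0) in expensive ones
    print(profit, np.round(P_schedule, 2), np.round(B_schedule, 2))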
@ -5,7 +5,6 @@ from pathlib import Path
from typing import Dict, Any, Optional, List

import torch
from sklearn.base import BaseEstimator, TransformerMixin  # For scaler type hint

# Import necessary components from forecasting_model
from forecasting_model.utils.forecast_config_model import MainConfig, FeatureConfig
@ -17,7 +16,8 @@ def load_single_model_artifact(
    model_path: Path,
    config_path: Path,
    input_size_path: Path,
    target_scaler_path: Optional[Path] = None
    target_scaler_path: Optional[Path] = None,
    data_scaler_path: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
    """
    Loads artifacts for a single trained model checkpoint.
@ -27,10 +27,11 @@ def load_single_model_artifact(
        config_path: Path to the corresponding main YAML config file.
        input_size_path: Path to the input_size.pt file.
        target_scaler_path: Optional path to the target_scaler.pt file.
        data_scaler_path: Optional path to the data_scaler.pt file.

    Returns:
        A dictionary containing loaded artifacts ('model_instance', 'feature_config',
        'target_scaler', 'main_forecasting_config'), or None if loading fails.
        'target_scaler', 'data_scaler', 'main_forecasting_config'), or None if loading fails.
    """
    logger.info(f"Loading single model artifact from directory: {model_path.parent}")
    loaded_artifacts = {}
@ -51,7 +52,7 @@ def load_single_model_artifact(
        if not input_size_path.is_file():
            logger.error(f"Input size file not found at {input_size_path}")
            return None
        input_size = torch.load(input_size_path)
        input_size = torch.load(input_size_path, weights_only=False)
        if not isinstance(input_size, int) or input_size <= 0:
            logger.error(f"Invalid input size loaded from {input_size_path}: {input_size}")
            return None
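        # Note (sketch, not part of the commit): weights_only=False is passed here and
        # below because recent PyTorch versions default torch.load to weights_only=True,
        # which refuses to unpickle arbitrary Python objects such as the plain int saved
        # for input_size or the sklearn scaler instances loaded further down.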
@ -61,69 +62,69 @@ def load_single_model_artifact(
        target_scaler = None
        if target_scaler_path:
            if not target_scaler_path.is_file():
                logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without scaler.")
                logger.warning(f"Target scaler file not found at {target_scaler_path}. Proceeding without target scaler.")
            else:
                try:
                    target_scaler = torch.load(target_scaler_path)
                    # Basic check if it looks like a scaler
                    if not isinstance(target_scaler, (BaseEstimator, TransformerMixin)):
                        logger.warning(f"Loaded object from {target_scaler_path} might not be a valid scaler ({type(target_scaler)}).")
                        # Decide if this should be a hard failure or just a warning
                    else:
                        logger.debug(f"Loaded target scaler from {target_scaler_path}")
                    target_scaler = torch.load(target_scaler_path, weights_only=False)
                except Exception as e:
                    logger.error(f"Error loading target scaler from {target_scaler_path}: {e}", exc_info=True)
                    # Decide if this should be a hard failure
                    return None  # Fail hard if scaler loading fails
                    return None
        loaded_artifacts['target_scaler'] = target_scaler

        # 4. Initialize Model Architecture
        # Ensure model config forecast horizon matches feature config (should be guaranteed by MainConfig validation)
        if set(main_config.model.forecast_horizon) != set(main_config.features.forecast_horizon):
            logger.warning(f"Mismatch between model ({main_config.model.forecast_horizon}) and feature ({main_config.features.forecast_horizon}) forecast horizons in config {config_path}. Using feature config.")
            # This might indicate an issue with the saved config, but we proceed using the feature config horizon
            # main_config.model.forecast_horizon = main_config.features.forecast_horizon # Correct it for model init? Risky.
        # 4. Load Data Scaler (Optional)
        data_scaler = None
        if data_scaler_path:
            if not data_scaler_path.is_file():
                logger.warning(f"Data scaler file not found at {data_scaler_path}. Proceeding without data scaler.")
            else:
                try:
                    data_scaler = torch.load(data_scaler_path, weights_only=False)
                except Exception as e:
                    logger.error(f"Error loading data scaler from {data_scaler_path}: {e}", exc_info=True)
                    return None
        loaded_artifacts['data_scaler'] = data_scaler


        # 5. Initialize Model Architecture
        model_instance = LSTMForecastLightningModule(
            model_config=main_config.model,
            train_config=main_config.training,  # Pass train config if needed
            input_size=input_size,
            target_scaler=target_scaler  # Pass scaler to model if it uses it internally during inference
            target_scaler=target_scaler,
            data_scaler=data_scaler
        )
        logger.debug("Initialized model architecture.")

        # 5. Load Model State Dictionary
        # 6. Load Model State Dictionary
        if not model_path.is_file():
            logger.error(f"Model checkpoint file not found at {model_path}")
            return None
        # Load onto CPU first to avoid GPU memory issues if the loading machine is different
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)

        # Adjust state dict keys if saved with 'model.' prefix from Lightning wrapper common during saving ckpt
        if any(key.startswith('model.') for key in state_dict.get('state_dict', state_dict).keys()):
            state_dict = {k.partition('model.')[2]: v for k, v in state_dict.get('state_dict', state_dict).items()}
        if any(key.startswith('model.') for key in model_dict.get('state_dict', model_dict).keys()):
            model_dict = {k.partition('model.')[2]: v for k, v in model_dict.get('state_dict', model_dict).items()}
            logger.debug("Adjusted state dict keys (removed 'model.' prefix).")

        # Load the state dict
        # Use strict=False initially if unsure about exact key matching, but strict=True is safer
        try:
            load_result = model_instance.load_state_dict(state_dict, strict=True)
            logger.debug(f"Model state loaded. Result: {load_result}")
        except RuntimeError as e:
            logger.error(f"Error loading state dict into model (strict=True): {e}. Trying strict=False.")
            try:
                load_result = model_instance.load_state_dict(state_dict, strict=False)
                logger.warning(f"Model state loaded with strict=False. Result: {load_result}. Check for missing/unexpected keys.")
            except Exception as e_false:
                logger.error(f"Failed to load state dict even with strict=False: {e_false}", exc_info=True)
                return None

        state_dict = model_dict.get('state_dict', None)

        # Load the state dict, use strict=True for safety
        if state_dict is not None:
            try:
                load_result = model_instance.load_state_dict(state_dict, strict=True)
                logger.debug(f"Model state loaded. Result: {load_result}")
            except RuntimeError as e:
                logger.error(f"Error loading state dict into model: {e}. Mismatched keys or unexpected keys found.", exc_info=True)
                logger.error(f"Model keys: {list(model_instance.state_dict().keys())[:20]}...")  # Show some model keys
                logger.error(f"Checkpoint keys: {list(state_dict.keys())[:20]}...")  # Show some checkpoint keys
                return None
        else:
            raise ValueError("State dict is not supposed to be None.")

        model_instance.eval()  # Set model to evaluation mode
        loaded_artifacts['model_instance'] = model_instance
        logger.info(f"Successfully loaded single model artifact: {model_path.name}")

        return loaded_artifacts

    except FileNotFoundError:
        logger.error(f"A required file was not found during artifact loading for {model_path.parent}.", exc_info=True)
        return None
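# --- Illustrative sketch (not part of the commit): stripping the 'model.' prefix ---
# Lightning checkpoints often store weights under 'model.<param>'. The loading code
# above removes that prefix with str.partition; a toy demonstration:
ckpt_keys = {'model.lstm.weight_ih_l0': 1, 'model.head.bias': 2}
stripped = {k.partition('model.')[2]: v for k, v in ckpt_keys.items()}
# -> {'lstm.weight_ih_l0': 1, 'head.bias': 2}
# Caveat: a key without the prefix would map to '' under this pattern, so the
# comprehension assumes every checkpoint key carries the prefix.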
@ -133,6 +134,7 @@ def load_single_model_artifact(
    except Exception as e:
        logger.error(f"Failed to load single model artifact from {model_path.parent}: {e}", exc_info=True)
        return None
    return loaded_artifacts


def load_ensemble_artifact(
@ -193,6 +195,9 @@ def load_ensemble_artifact(
    loaded_fold_artifacts: List[Dict[str, Any]] = []
    common_feature_config: Optional[FeatureConfig] = None
    common_target_col: Optional[str] = None
    # Flag to track if data scalers are expected based on the first loaded fold
    # None = Undetermined, True = Expected, False = Not expected/present
    data_scaler_expected: Optional[bool] = None

    logger.info(f"Loading artifacts for {len(fold_models_definitions)} folds defined in the ensemble...")
@ -202,7 +207,8 @@ def load_ensemble_artifact(
        logger.debug(f"--- Loading Fold {fold_id} ---")

        model_path_rel = fold_def.get("model_path")
        scaler_path_rel = fold_def.get("target_scaler_path")
        target_scaler_path_rel = fold_def.get("target_scaler_path")  # Keep target scaler path separate
        data_scaler_path_rel = fold_def.get("data_scaler_path")  # Get the data scaler path
        input_size_path_rel = fold_def.get("input_size_path")
        config_path_rel = fold_def.get("config_path")
@ -215,19 +221,21 @@ def load_ensemble_artifact(
            abs_model_path = (absolute_artifacts_base_dir / Path(model_path_rel)).resolve()
            abs_input_size_path = (absolute_artifacts_base_dir / Path(input_size_path_rel)).resolve()
            abs_config_path = (absolute_artifacts_base_dir / Path(config_path_rel)).resolve()
            abs_scaler_path = (absolute_artifacts_base_dir / Path(scaler_path_rel)).resolve() if scaler_path_rel else None

            abs_target_scaler_path = (absolute_artifacts_base_dir / Path(target_scaler_path_rel)).resolve() if target_scaler_path_rel else None
            abs_data_scaler_path = (absolute_artifacts_base_dir / Path(data_scaler_path_rel)).resolve() if data_scaler_path_rel else None  # Resolve data scaler path
            logger.debug(f"Fold {fold_id} - Model Path: {abs_model_path}")
            logger.debug(f"Fold {fold_id} - Config Path: {abs_config_path}")
            logger.debug(f"Fold {fold_id} - Input Size Path: {abs_input_size_path}")
            logger.debug(f"Fold {fold_id} - Scaler Path: {abs_scaler_path}")
            logger.debug(f"Fold {fold_id} - Target Scaler Path: {abs_target_scaler_path}")
            logger.debug(f"Fold {fold_id} - Data Scaler Path: {abs_data_scaler_path}")  # Log data scaler path

            # Load the artifacts for this single fold using the other function
            single_fold_loaded_artifacts = load_single_model_artifact(
                model_path=abs_model_path,
                config_path=abs_config_path,
                input_size_path=abs_input_size_path,
                target_scaler_path=abs_scaler_path
                target_scaler_path=abs_target_scaler_path,
                data_scaler_path=abs_data_scaler_path  # Pass the data scaler path
            )

            if single_fold_loaded_artifacts:
@ -237,61 +245,71 @@ def load_ensemble_artifact(
                logger.info(f"Successfully loaded artifacts for fold {fold_id}.")

                # --- Consistency Check (Optional but Recommended) ---
                # Store the feature config and target col from the first successful fold
                # Then compare subsequent folds against these
                current_feature_config = single_fold_loaded_artifacts['feature_config']
                current_target_col = single_fold_loaded_artifacts['main_forecasting_config'].data.target_col
                current_data_scaler = single_fold_loaded_artifacts.get('data_scaler')  # Check if data scaler was loaded

                if common_feature_config is None:
                    common_feature_config = current_feature_config
                    common_target_col = current_target_col
                    logger.debug(f"Set common feature config and target column based on fold {fold_id}.")
                    # Determine if data scalers are expected based on this first fold
                    data_scaler_expected = (current_data_scaler is not None)
                    logger.debug(f"Set common config, target column, and data_scaler_expected={data_scaler_expected} based on fold {fold_id}.")
                else:
                    # Compare crucial feature engineering aspects
                    if common_feature_config.sequence_length != current_feature_config.sequence_length or \
                       set(common_feature_config.forecast_horizon) != set(current_feature_config.forecast_horizon) or \
                       common_feature_config.scaling_method != current_feature_config.scaling_method:  # Add more checks if needed
                        logger.error(f"Fold {fold_id}: Feature configuration mismatch with previous folds. Cannot proceed with this ensemble definition.")
                        # You might want to compare more fields like lags, rolling_windows etc.
                        return None  # Fail hard if configs are inconsistent
                    if common_target_col != current_target_col:
                        logger.error(f"Fold {fold_id}: Target column '{current_target_col}' differs from previous folds ('{common_target_col}'). Cannot proceed.")
                        return None  # Fail hard

                    # Check consistency of data scaler presence/absence
                    current_fold_has_data_scaler = (current_data_scaler is not None)
                    if data_scaler_expected != current_fold_has_data_scaler:
                        logger.error(f"Fold {fold_id}: Inconsistent presence of data scaler compared to previous folds (Expected: {data_scaler_expected}, Found: {current_fold_has_data_scaler}). Cannot proceed.")
                        # Ideally, also check scaler *type* if present, but presence check is a start
                        # if data_scaler_expected and type(common_data_scaler) != type(current_data_scaler): # Example type check
                        #     logger.error(...)
                        return None  # Fail hard on inconsistent data scaler presence

            else:
                logger.error(f"Failed to load artifacts for fold {fold_id}. Skipping fold.")
                # Decide if ensemble loading should fail if *any* fold fails
                # For now, we continue and will check if enough folds loaded later
                # Fail hard if any fold fails, as inconsistency is likely
                return None

        except TypeError as e:
            # Catch potential errors if paths are None or invalid types
            logger.error(f"Fold {fold_id}: Error constructing artifact paths - check definition file content: {e}", exc_info=True)
            continue
            return None  # Fail hard if path construction fails
        except Exception as e:
            logger.error(f"Fold {fold_id}: Unexpected error during loading: {e}", exc_info=True)
            continue  # Skip this fold
            return None  # Fail hard on unexpected errors per fold


    # --- Final Checks and Return ---
    if not loaded_fold_artifacts:
        logger.error("Failed to load artifacts for *any* fold in the ensemble.")
        return None

    # Add a check if a minimum number of folds is required (e.g., > 1)
    if len(loaded_fold_artifacts) < 1:  # Or maybe check against len(fold_models_definitions)?
        logger.error(f"Only successfully loaded {len(loaded_fold_artifacts)} folds, which might be insufficient for the ensemble.")
        # Decide if this is an error or just a warning
    # Check if number of loaded folds matches definition (since we fail hard on individual fold errors now)
    if len(loaded_fold_artifacts) != len(fold_models_definitions):
        logger.error(f"Loaded {len(loaded_fold_artifacts)} folds, but expected {len(fold_models_definitions)}. Indicates an earlier loading failure for some folds.")
        # This path might not be reachable if we fail hard above, but keep as safeguard
        return None

    if common_feature_config is None or common_target_col is None:
        # This should not happen if loaded_fold_artifacts is not empty, but check anyway
        # This should not happen if loaded_fold_artifacts is not empty and we fail hard, but check anyway
        logger.error("Internal error: Could not determine common feature config or target column for the ensemble.")
        return None

    logger.info(f"Successfully loaded artifacts for {len(loaded_fold_artifacts)} ensemble folds.")
    logger.info(f"Successfully loaded artifacts for all {len(loaded_fold_artifacts)} ensemble folds.")

    return {
        'ensemble_method': ensemble_method,
        'fold_artifacts': loaded_fold_artifacts,  # List of dicts
        'fold_artifacts': loaded_fold_artifacts,  # List of dicts (now includes data_scaler)
        'ensemble_feature_config': common_feature_config,  # The common config
        'ensemble_target_col': common_target_col  # The common target column name
    }
@ -8,6 +8,11 @@ class ModelEvalConfig(BaseModel):
    model_path: str = Field(..., description="Path to the saved PyTorch model file (.ckpt for type='model') or the ensemble definition JSON file (.json for type='ensemble').")
    model_config_path: str = Field(..., description="Path to the forecasting config (YAML) used for this model training (or for the best trial in an ensemble).")
    target_scaler_path: Optional[str] = Field(None, description="Path to the target scaler file for the single model (or will be loaded per fold for ensemble).")
    data_scaler_path: Optional[str] = Field(None, description="Path to the data scaler file for the single model (or will be loaded per fold for ensemble).")
    # test_loader_path: Optional[str] = Field(None, description="Path to the test loader file for the single model (or will be loaded per fold for ensemble).")
    input_size_path: Optional[str] = Field(None, description="Path to the input size file for the single model (or will be loaded per fold for ensemble).")


class OptimizationRunConfig(BaseModel):
    """Main configuration for running battery optimization with multiple models/ensembles."""
@ -15,4 +20,5 @@ class OptimizationRunConfig(BaseModel):
    max_capacity: float = Field(..., description="Maximum energy capacity of the battery (MWh).")
    max_rate: float = Field(..., description="Maximum charge/discharge power rate of the battery (MW).")
    optimization_horizon_hours: int = Field(24, gt=0, description="The length of the time window (in hours) for optimization.")
    output_dir: str = Field(..., description="Output directory for the optimization results.")
    models: List[ModelEvalConfig] = Field(..., description="List of forecasting models or ensembles to evaluate.")
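# --- Illustrative sketch (not part of the commit): instantiating the run config ---
# Only fields visible in this hunk are shown; both models likely define further
# required fields outside this diff (the hunks start mid-class), so the calls below
# may need extra arguments. All paths and numbers are hypothetical.
example_model = ModelEvalConfig(
    model_path="artifacts/model.ckpt",
    model_config_path="artifacts/config.yaml",
    target_scaler_path="artifacts/target_scaler.pt",
    data_scaler_path="artifacts/data_scaler.pt",
    input_size_path="artifacts/input_size.pt",
)
example_run = OptimizationRunConfig(
    max_capacity=4.0,                 # MWh
    max_rate=2.0,                     # MW
    optimization_horizon_hours=24,
    output_dir="output/optimization",
    models=[example_model],
)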
@ -17,19 +17,19 @@ from pytorch_lightning.callbacks import EarlyStopping
# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
from forecasting_model.utils.data_processing import (
    prepare_fold_data_and_loaders,
    split_data_classic
)
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.train.classic import run_classic_training
from forecasting_model.train.classic import run_model_training


# Import helper functions from forecasting_model_run.py
from forecasting_model.utils.helper import load_config, set_seeds

# Import the data processing functions
from forecasting_model.data_processing import load_raw_data
from forecasting_model import load_raw_data

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional

@ -124,7 +124,7 @@ def objective(
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Hyperparameter Suggestions -----

@ -156,7 +156,6 @@ def objective(

# --- 3. Run Classic Train/Test ---
logger.info(f"Trial {trial.number}: Starting Classic Run...")
validation_metric_value = worst_value # Initialize to worst
try:
n_samples = len(df)
val_frac = trial_config.cross_validation.val_size_fraction
@ -165,7 +164,7 @@ def objective(


# Prepare data for classic split
train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, input_size_cl = prepare_fold_data_and_loaders(
train_loader_cl, val_loader_cl, test_loader_cl, target_scaler_cl, data_scaler, input_size_cl = prepare_fold_data_and_loaders(
full_df=df, train_idx=train_idx_cl, val_idx=val_idx_cl, test_idx=test_idx_cl,
target_col=trial_config.data.target_col, feature_config=trial_config.features,
train_config=trial_config.training, eval_config=trial_config.evaluation
@ -174,7 +173,7 @@ def objective(
# Initialize Model
model_cl = LSTMForecastLightningModule(
model_config=trial_config.model, train_config=trial_config.training,
input_size=input_size_cl, target_scaler=target_scaler_cl
input_size=input_size_cl, target_scaler=target_scaler_cl, data_scaler=data_scaler
)

# Callbacks (EarlyStopping and Pruning)
@ -188,6 +187,7 @@ def objective(

# Trainer for classic run
trainer_cl = pl.Trainer(
check_val_every_n_epoch=trial_config.training.check_val_n_epoch,
accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1 if torch.cuda.is_available() else None,
max_epochs=trial_config.training.epochs, callbacks=callbacks_cl, logger=False, # logger=False as per original
enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False,
@ -234,7 +234,7 @@ def objective(
logger.info(f"Trial {trial.number}: Finished Classic Run in {time.perf_counter() - trial_start_time:.2f}s")

except optuna.TrialPruned:
# Propagate prune signal, objective will be set to worst later by Optuna
# Propagate prune signal, objective will be set to "worst" later by Optuna
raise
except Exception as e:
logger.error(f"Trial {trial.number}: Failed during classic run phase: {e}", exc_info=True)

@ -260,43 +260,65 @@ def run_hpo():
args = parse_arguments()
config_path = Path(args.config)
try:
base_config = load_config(config_path) # Load base config once
base_config = load_config(config_path, MainConfig) # Load base config once
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)

# Setup output dir...
# --- Setup Output Dir ---
# 1. Determine the main output directory
if args.output_dir:
hpo_base_output_dir = Path(args.output_dir) # Use specific name for HPO dir
logger.info(f"Using HPO output directory from command line: {hpo_base_output_dir}")
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
logger.info(f"Using HPO output directory from Optuna storage path: {hpo_base_output_dir}")
main_output_dir = Path(args.output_dir)
logger.info(f"Using main output directory from command line: {main_output_dir}")
elif hasattr(base_config, 'output_dir') and base_config.output_dir:
main_output_dir = Path(base_config.output_dir)
logger.info(f"Using main output directory from config file: {main_output_dir}")
else:
# Use output_dir from main config if available, otherwise default
main_output_dir = Path(getattr(base_config, 'output_dir', 'output'))
hpo_base_output_dir = main_output_dir / 'hpo_results'
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
main_output_dir = Path("output") # Default
logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")

# Setup logging... (ensure file handler uses hpo_base_output_dir)
# 2. Define the specific directory for this classic HPO run
classic_hpo_output_dir = main_output_dir / "classic"

# 3. Create directories
main_output_dir.mkdir(parents=True, exist_ok=True)
classic_hpo_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Classic HPO outputs will be saved under: {classic_hpo_output_dir}")

# --- Setup Logging ---
try:
level_name = base_config.log_level.upper()
effective_log_level = logging.getLevelName(level_name)
log_file = hpo_base_output_dir / f"{base_config.optuna.study_name}_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a')
# Ensure study name is filesystem-safe
safe_study_name = base_config.optuna.study_name
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in safe_study_name)
# Place log file directly inside the classic HPO directory
log_file = classic_hpo_output_dir / f"{safe_study_name}_hpo.log" # Changed filename slightly
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') # Add encoding
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# Add handler only if it's not already added (e.g., if run_hpo is called multiple times)
# Add handler only if it's not already added
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.setLevel(effective_log_level)
logger.info(f"Set log level to {level_name}. Logging HPO run to console and {log_file}")
if effective_log_level <= logging.DEBUG: logger.debug("Debug logging enabled.")
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError, TypeError) as e: # Add TypeError
logger.warning(f"Could not set log level from config. Defaulting to INFO. Error: {e}")
logger.setLevel(logging.INFO)
# Still try to log to a default file if possible
try:
# Default log file also goes into the specific classic directory
log_file = classic_hpo_output_dir / "default_classic_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == str(log_file.resolve()) for h in logger.handlers):
logger.addHandler(file_handler)
logger.info(f"Logging to default file: {log_file}")
except Exception as log_e:
logger.error(f"Failed to set up default file logging: {log_e}")


# Setup seeding...

@ -326,39 +348,42 @@ def run_hpo():
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)

storage_path = hpo_config.storage
if storage_path and storage_path.startswith("sqlite:///"):
db_path = Path(storage_path.replace("sqlite:///", ""))
if not db_path.is_absolute():
db_path = hpo_base_output_dir / db_path # Relative to HPO output dir
db_path.parent.mkdir(parents=True, exist_ok=True)
storage_string = hpo_config.storage
storage_path = None # Initialize

if storage_string and storage_string.startswith("sqlite:///"):
db_filename = storage_string.replace("sqlite:///", "").strip()
if not db_filename:
# Use study name if filename is empty
db_filename = f"{safe_study_name}.db" # Default DB name for classic
logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
# Place the DB file inside the classic HPO directory
db_path = classic_hpo_output_dir / db_filename
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using Optuna storage: {storage_path}")
elif storage_path:
logger.info(f"Using Optuna storage: {storage_path} (non-SQLite)")
logger.info(f"Using SQLite storage: {storage_path}")
elif storage_string:
# Assume it's a non-SQLite connection string or a pre-configured path
storage_path = storage_string
logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: DB file will not be placed inside {classic_hpo_output_dir}")
else:
logger.warning("No Optuna storage DB specified in config, using in-memory storage.")
storage_path = None # Explicitly set to None for in-memory
logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")


try:
# Change to single objective 'minimize'
study = optuna.create_study(
study_name=hpo_config.study_name,
storage=storage_path,
direction="minimize", # Changed to single direction
direction=hpo_config.direction, # Use direction from config
load_if_exists=True,
pruner=optuna.pruners.MedianPruner() if hpo_config.pruning else optuna.pruners.NopPruner()
)
# Remove multi-objective check/attribute setting
# if not study._is_multi_objective:
# logger.warning(f"Study '{hpo_config.study_name}' exists but is not multi-objective.")
# elif 'objective_names' not in study.user_attrs:
# study.set_user_attr('objective_names', objective_names)


# --- Run Optimization ---
logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}") # Updated log message
logger.info(f"Starting Optuna single-objective optimization: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df), # Pass base_config
lambda trial: objective(trial, base_config, df), # Objective doesn't need the path here
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True

@ -368,11 +393,13 @@ def run_hpo():
logger.info("--- Optuna HPO Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

# Get the single best trial
best_trial = study.best_trial
if best_trial is None:
best_trial = None
try:
best_trial = study.best_trial
except ValueError:
logger.warning("Optuna study finished, but no successful trial was completed.")
else:

if best_trial:
logger.info(f"Best trial found (Trial {best_trial.number}):")
# Log details for the best trial
validation_metric_monitor = base_config.optuna.metric_to_optimize
@ -383,8 +410,9 @@ def run_hpo():

# --- Re-run and Save Artifacts for the Best Trial ---
logger.info(f"-> Re-running Best Trial {best_trial.number} to save artifacts...")
trial_output_dir = hpo_base_output_dir / f"best_trial_num{best_trial.number}" # Simplified directory name
trial_output_dir.mkdir(parents=True, exist_ok=True)
# Define the output directory for this specific best trial run
best_trial_output_dir = classic_hpo_output_dir / f"best_trial_num{best_trial.number}"
best_trial_output_dir.mkdir(parents=True, exist_ok=True)

try:
# 1. Create config for this trial

@ -402,41 +430,53 @@ def run_hpo():

# Ensure evaluation plots and model saving are enabled for this final run
best_config_dict['evaluation']['save_plots'] = True
best_config_dict['training']['save_model'] = True # Assuming you want to save the best model
# Add a flag to save the model if not already present/configurable
# best_config_dict['training']['save_model'] = True # Assuming you have this or handle it in run_model_training

# Validate the final config for this trial
best_trial_config = MainConfig(**best_config_dict)

# Save the specific config used for this run
with open(trial_output_dir / "best_config.yaml", 'w') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False)
# Save the specific config used for this best run inside its directory
with open(best_trial_output_dir / "best_run_config.yaml", 'w', encoding='utf-8') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

# 2. Run classic training (which saves model & plots)
# 2. Run classic training, saving outputs under best_trial_output_dir
logger.info(f"-> Running classic training for Best Trial {best_trial.number}...")
# Pass the specific config and output directory
run_classic_training(
run_model_training(
config=best_trial_config,
full_df=df,
output_base_dir=trial_output_dir # outputs -> hpo_results/best_trial_num<n>/classic_run/
# Pass the specific directory for this run's artifacts
output_base_dir=best_trial_output_dir
)
logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {trial_output_dir}")
logger.info(f"-> Finished re-running and saving artifacts for Best Trial {best_trial.number} to {best_trial_output_dir}")

except Exception as e:
logger.error(f"-> Failed to re-run or save artifacts for Best Trial {best_trial.number}: {e}", exc_info=True)

# --- Save Best Hyperparameters ---
best_params_file = hpo_base_output_dir / f"{hpo_config.study_name}_best_params.json" # Simplified filename
# --- Save Best Hyperparameters and Config at the Top Level ---
# Save best parameters file directly into the classic HPO output dir
best_params_file = classic_hpo_output_dir / f"{safe_study_name}_best_params.json"
try:
with open(best_params_file, 'w') as f:
with open(best_params_file, 'w', encoding='utf-8') as f:
import json
json.dump(best_trial.params, f, indent=4) # Use best_trial.params
logger.info(f"Hyperparameters of the best trial saved to {best_params_file}")
except Exception as e:
logger.error(f"Failed to save parameters for best trial: {e}")
logger.error(f"Failed to save parameters for best trial: {e}", exc_info=True)

except Exception as e:
logger.critical(f"A critical error occurred during the Optuna study: {e}", exc_info=True)
sys.exit(1)
# Save the best config file directly into the classic HPO output dir
best_config_file = classic_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
try:
# best_config_dict should still hold the config from the re-run step above
with open(best_config_file, 'w', encoding='utf-8') as f:
yaml.dump(best_config_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
logger.info(f"Configuration for best trial saved to {best_config_file}")
except Exception as e:
logger.error(f"Failed to save best configuration: {e}", exc_info=True)


except optuna.exceptions.StorageInternalError as e:
logger.critical(f"Optuna storage error: {e}. Check storage path/permissions: {storage_path}", exc_info=True)
sys.exit(1)

if __name__ == "__main__":
run_hpo()

@ -2,38 +2,29 @@ import argparse
import logging
import sys
import warnings
import copy # For deep copying config
import copy
from pathlib import Path
import time
import numpy as np
import pandas as pd
import torch
import yaml
import json # Import json to save best ensemble definition

import optuna

# Import necessary components from the forecasting_model package
from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import (
load_raw_data,
TimeSeriesCrossValidationSplitter,
# prepare_fold_data_and_loaders used by run_single_fold
)
# Import the single fold runner from the main script
from forecasting_model import TimeSeriesCrossValidationSplitter, load_raw_data
from forecasting_model_run import run_single_fold
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
from typing import List, Optional, Tuple, Dict, Any # Added Any for dictionary
from typing import List, Dict, Any

# Import helper functions
from forecasting_model.utils.helper import load_config, set_seeds, aggregate_cv_metrics, save_results
from forecasting_model.utils.helper import load_config, set_seeds

# --- Suppress specific PL warnings about logger=True with no logger ---
# This is expected behavior in optuna_run.py where logger=False is intentional
warnings.filterwarnings(
"ignore",
message=".*You called `self.log.*logger=True.*no logger configured.*",
category=UserWarning, # These specific warnings are often UserWarnings
category=UserWarning,
module="pytorch_lightning.core.module"
)

@ -43,10 +34,9 @@ mpl_logger.setLevel(logging.WARNING)
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.WARNING)
pl_logger = logging.getLogger('pytorch_lightning')
pl_logger.setLevel(logging.WARNING) # Set PL to WARNING by default, INFO/DEBUG set below if needed
pl_logger.setLevel(logging.WARNING)

# --- Basic Logging Setup ---
# Configure logging early. Level will be set properly later based on config.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)-7s - %(message)s',
datefmt='%H:%M:%S')

@ -81,7 +71,7 @@ def objective(
trial: optuna.Trial,
base_config: MainConfig,
df: pd.DataFrame,
hpo_base_output_dir: Path # Pass base dir for trial outputs
ensemble_hpo_output_dir: Path # Renamed parameter for clarity
) -> float: # Return the single ensemble metric to optimize
"""
Optuna objective function optimizing ensemble performance.
@ -90,7 +80,7 @@ def objective(
trial_start_time = time.perf_counter()

# Define trial-specific output directory for fold artifacts
trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{trial.number}"
trial_artifacts_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Trial artifacts will be saved to: {trial_artifacts_dir}")

@ -116,12 +106,12 @@ def objective(
# ----- Suggest Hyperparameters -----
# Modify trial_config_dict using trial.suggest_*
trial_config_dict['training']['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256])
trial_config_dict['training']['batch_size'] = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256, 512])
trial_config_dict['training']['loss_function'] = trial.suggest_categorical('loss_function', ['MSE', 'MAE'])
trial_config_dict['model']['hidden_size'] = trial.suggest_int('hidden_size', 18, 498, step=32)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 8)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.25, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 24, 168, step=12)
trial_config_dict['model']['num_layers'] = trial.suggest_int('num_layers', 1, 3)
trial_config_dict['model']['dropout'] = trial.suggest_float('dropout', 0.0, 0.5, step=0.05)
trial_config_dict['features']['sequence_length'] = trial.suggest_int('sequence_length', 3, 72, step=2)
trial_config_dict['features']['scaling_method'] = trial.suggest_categorical('scaling_method', ['standard', 'minmax', None])
use_configured_lags = trial.suggest_categorical('use_lags', [True, False])
if not use_configured_lags: trial_config_dict['features']['lags'] = []
@ -129,7 +119,7 @@ def objective(
if not use_configured_rolling: trial_config_dict['features']['rolling_window_sizes'] = []
trial_config_dict['features']['use_time_features'] = trial.suggest_categorical('use_time_features', [True, False])
trial_config_dict['features']['sinus_curve'] = trial.suggest_categorical('sinus_curve', [True, False])
trial_config_dict['features']['cosin_curve'] = trial.suggest_categorical('cosin_curve', [True, False])
trial_config_dict['features']['cosine_curve'] = trial.suggest_categorical('cosine_curve', [True, False])
trial_config_dict['features']['fill_nan'] = trial.suggest_categorical('fill_nan', ['ffill', 'bfill', 0])
# ----- End of Suggestions -----

@ -164,8 +154,8 @@ def objective(
all_fold_best_val_scores = {} # Store best val scores for pruning
actual_folds_trained = 0
# Store paths to saved models and scalers for this trial
fold_model_paths = []
fold_scaler_paths = []
# fold_model_paths = [] # Removed, using fold_artifact_details instead
# fold_scaler_paths = [] # Removed, using fold_artifact_details instead
try:
cv_splitter = TimeSeriesCrossValidationSplitter(trial_config.cross_validation, len(df))

@ -176,13 +166,13 @@ def objective(

try:
# Use run_single_fold - it handles training and saving artifacts
# Pass trial_output_dir so fold artifacts are saved per trial
fold_metrics, best_val_score, saved_model_path, saved_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path = run_single_fold(
fold_num=fold_num,
train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
config=trial_config, # Use the config with trial's hyperparameters
config=trial_config,
full_df=df,
output_base_dir=trial_artifacts_dir # Save folds under trial dir
output_base_dir=trial_artifacts_dir,
enable_progress_bar=False
)
actual_folds_trained += 1
all_fold_best_val_scores[fold_id] = best_val_score
@ -191,7 +181,8 @@ def objective(
fold_artifact_details.append({
"fold_id": fold_id,
"model_path": str(saved_model_path) if saved_model_path else None,
"target_scaler_path": str(saved_scaler_path) if saved_scaler_path else None,
"target_scaler_path": str(saved_target_scaler_path) if saved_target_scaler_path else None,
"data_scaler_path": str(saved_data_scaler_path) if saved_data_scaler_path else None,
"input_size_path": str(saved_input_size_path) if saved_input_size_path else None,
"config_path": str(saved_config_path) if saved_config_path else None,
})
@ -216,9 +207,7 @@ def objective(
except Exception as e:
logger.error(f"Trial {trial.number}, Fold {fold_id}: Failed CV fold training: {e}", exc_info=True)
all_fold_best_val_scores[fold_id] = None # Mark fold as failed
# Continue to next fold if possible, but report worst value for this fold
trial.report(worst_value, fold_num)
# Optionally raise prune here if too many folds fail? Or let the ensemble eval handle it.

except optuna.TrialPruned:
logger.info(f"Trial {trial.number}: Pruned during CV training phase.")
@ -239,8 +228,8 @@ def objective(
try:
# Run evaluation using the artifacts saved in the trial's output directory
ensemble_results = run_ensemble_evaluation(
config=trial_config, # Pass trial config
output_base_dir=trial_artifacts_dir # Directory containing trial's fold subdirs
config=trial_config,
output_base_dir=trial_artifacts_dir # Pass the specific trial's artifact dir
)

if ensemble_results:

@ -326,34 +315,44 @@ run_hpo():
args = parse_arguments()
config_path = Path(args.config)
try:
base_config = load_config(config_path)
base_config = load_config(config_path, MainConfig)
logger.info(f"Successfully loaded configuration from {config_path}")
except Exception as e:
logger.critical(f"Failed to load configuration from {config_path}: {e}", exc_info=True)
sys.exit(1)

# --- Setup Output Dir ---
# 1. Determine the main output directory
if args.output_dir:
hpo_base_output_dir = Path(args.output_dir)
elif base_config.optuna.storage and base_config.optuna.storage.startswith("sqlite:///"):
hpo_base_output_dir = Path(base_config.optuna.storage.replace("sqlite:///", "")).parent
# Command-line argument overrides config
main_output_dir = Path(args.output_dir)
logger.info(f"Using main output directory from command line: {main_output_dir}")
elif hasattr(base_config, 'output_dir') and base_config.output_dir:
main_output_dir = Path(base_config.output_dir)
logger.info(f"Using main output directory from config file: {main_output_dir}")
else:
# Fallback to default if output_dir is not in config either
main_output_dir_str = getattr(base_config, 'output_dir', 'output')
if not main_output_dir_str: # Handle empty string case
main_output_dir_str = 'output'
main_output_dir = Path(main_output_dir_str)
hpo_base_output_dir = main_output_dir / f'{base_config.optuna.study_name}_ensemble_hpo' # Specific subdir using study name
hpo_base_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Using HPO output directory: {hpo_base_output_dir}")
main_output_dir = Path("output") # Default if not specified anywhere
logger.warning(f"No output directory specified in config or args, defaulting to: {main_output_dir}")

# 2. Define the specific directory for this ensemble HPO run
ensemble_hpo_output_dir = main_output_dir / "ensemble"

# 3. Create directories
main_output_dir.mkdir(parents=True, exist_ok=True)
ensemble_hpo_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Ensemble HPO outputs will be saved under: {ensemble_hpo_output_dir}")


# --- Setup Logging ---
try:
level_name = base_config.log_level.upper()
# logging.getLevelName() maps a level name to its numeric value
effective_log_level = logging.getLevelName(level_name)
# Ensure study name is filesystem-safe if used directly
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in base_config.optuna.study_name)
log_file = hpo_base_output_dir / f"{safe_study_name}_ensemble_hpo.log"
safe_study_name = base_config.optuna.study_name
safe_study_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in safe_study_name)
# Place log file directly inside the ensemble HPO directory
log_file = ensemble_hpo_output_dir / f"{safe_study_name}_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') # Specify encoding
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)

@ -368,7 +367,8 @@ def run_hpo():
logger.setLevel(logging.INFO)
# Still try to log to a default file if possible
try:
log_file = hpo_base_output_dir / "default_ensemble_hpo.log"
# Default log file also goes into the specific ensemble directory
log_file = ensemble_hpo_output_dir / "default_ensemble_hpo.log"
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
formatter = logging.Formatter('%(asctime)s - %(name)-25s - %(levelname)-7s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)

@ -404,26 +404,27 @@ def run_hpo():
logger.critical("Optuna configuration section ('optuna') missing.")
sys.exit(1)

storage_path = hpo_config.storage
if storage_path and storage_path.startswith("sqlite:///"):
db_path_str = storage_path.replace("sqlite:///", "")
if not db_path_str:
# Default filename if only 'sqlite:///' is provided
db_path = hpo_base_output_dir / f"{base_config.optuna.study_name}.db"
logger.warning(f"SQLite path was empty, defaulting to: {db_path}")
else:
db_path = Path(db_path_str)
storage_string = hpo_config.storage # Use a more descriptive name
storage_path = None # Initialize

if not db_path.is_absolute():
db_path = hpo_base_output_dir / db_path
db_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent dir exists
if storage_string and storage_string.startswith("sqlite:///"):
db_filename = storage_string.replace("sqlite:///", "").strip()
if not db_filename:
# Use study name if filename is empty
db_filename = f"{safe_study_name}_ensemble.db"
logger.warning(f"SQLite path in config was empty, using default filename: {db_filename}")
# Place the DB file inside the ensemble HPO directory
db_path = ensemble_hpo_output_dir / db_filename
storage_path = f"sqlite:///{db_path.resolve()}"
logger.info(f"Using SQLite storage: {storage_path}")
elif storage_path:
logger.info(f"Using Optuna storage: {storage_path} (Assuming non-SQLite or pre-configured)")
elif storage_string:
# Assume it's a non-SQLite connection string or a pre-configured path
storage_path = storage_string
logger.warning(f"Using non-SQLite Optuna storage: {storage_path}. Note: DB file will not be placed inside {ensemble_hpo_output_dir}")
else:
storage_path = None # Explicitly set to None for in-memory
logger.warning("No Optuna storage DB specified, using in-memory storage.")
logger.warning("No Optuna storage DB specified, using in-memory storage (results will be lost on exit).")


try:
# Single objective study based on ensemble performance
@ -438,7 +439,7 @@ def run_hpo():
# --- Run Optimization ---
logger.info(f"Starting Optuna optimization for ensemble performance: study='{hpo_config.study_name}', n_trials={hpo_config.n_trials}, direction='{hpo_config.direction}'")
study.optimize(
lambda trial: objective(trial, base_config, df, hpo_base_output_dir), # Pass base_config and output dir
lambda trial: objective(trial, base_config, df, ensemble_hpo_output_dir), # Pass ensemble output dir
n_trials=hpo_config.n_trials,
timeout=None,
gc_after_trial=True # Garbage collect after trial

@ -467,8 +468,8 @@ def run_hpo():
for key, value in best_params.items():
logger.info(f" {key}: {value}")

# Save best hyperparameters
best_params_file = hpo_base_output_dir / f"{safe_study_name}_best_params.json"
# Save best hyperparameters directly into the ensemble output dir
best_params_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_params.json"
try:
with open(best_params_file, 'w', encoding='utf-8') as f:
import json
@ -477,8 +478,8 @@ def run_hpo():
except Exception as e:
logger.error(f"Failed to save best parameters: {e}", exc_info=True)

# Save the corresponding config
best_config_file = hpo_base_output_dir / f"{safe_study_name}_best_config.yaml"
# Save the corresponding config directly into the ensemble output dir
best_config_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_config.yaml"
try:
# Use a fresh deepcopy to avoid modifying the original base_config
best_config_dict = copy.deepcopy(base_config.model_dump(mode='python'))
@ -491,6 +492,9 @@ def run_hpo():
if key in best_config_dict.get('training', {}): best_config_dict['training'][key] = value
elif key in best_config_dict.get('model', {}): best_config_dict['model'][key] = value
elif key in best_config_dict.get('features', {}): best_config_dict['features'][key] = value
elif key in ["use_lags", "use_rolling_windows"]:
# If False, this was already set to [] in the parameter suggestion section.
pass
else:
logger.warning(f"Best parameter '{key}' not found in expected config sections (training, model, features).")

@ -506,73 +510,79 @@ def run_hpo():
except Exception as e:
logger.error(f"Failed to save best configuration: {e}", exc_info=True)

# Retrieve saved artifact paths and best ensemble method from user attributes
best_trial_artifacts_dir = hpo_base_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
# fold_model_paths = best_trial.user_attrs.get("fold_model_paths", []) # Removed
# fold_scaler_paths = best_trial.user_attrs.get("fold_scaler_paths", []) # Removed
fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", []) # Retrieve comprehensive details
# --- Retrieve Artifacts and Save Ensemble Definition ---
# Base directory for this trial's artifacts
best_trial_artifacts_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" / f"trial_{best_trial.number}"
best_ensemble_method = best_trial.user_attrs.get("best_ensemble_method")
fold_artifact_details = best_trial.user_attrs.get("fold_artifact_details", [])

if not best_trial_artifacts_dir.exists():
logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
elif not best_ensemble_method:
logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
elif not fold_artifact_details: # Check if any artifact details were recorded
logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
else:
# --- Save Best Ensemble Definition ---
logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")
# Check if artifacts exist and data is available
if not best_trial_artifacts_dir.exists():
logger.error(f"Artifacts directory for best trial {best_trial.number} not found: {best_trial_artifacts_dir}. Cannot save best ensemble definition.")
elif not best_ensemble_method:
logger.error(f"Best ensemble method not recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
elif not fold_artifact_details: # Check if any artifact details were recorded
logger.error(f"No artifact details recorded for best trial {best_trial.number}. Cannot save best ensemble definition.")
else:
# --- Save Best Ensemble Definition ---
logger.info(f"Saving best ensemble definition for trial {best_trial.number}...")

ensemble_definition_file = hpo_base_output_dir / f"{safe_study_name}_best_ensemble.json"
# Save definition file directly into the ensemble output dir
ensemble_definition_file = ensemble_hpo_output_dir / f"{safe_study_name}_best_ensemble.json"

best_ensemble_definition = {
"trial_number": best_trial.number,
"objective_value": best_trial.value,
"hyperparameters": best_trial.params,
"ensemble_method": best_ensemble_method,
"fold_models": [], # List of dictionaries for each fold's model and scaler, input_size, config
"ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(hpo_base_output_dir)) # Save path relative to hpo_base_output_dir
}
best_ensemble_definition = {
"trial_number": best_trial.number,
"objective_value": best_trial.value,
"hyperparameters": best_trial.params,
"ensemble_method": best_ensemble_method,
# The base dir for artifacts, relative to the main ensemble output dir
"ensemble_artifacts_base_dir": str(best_trial_artifacts_dir.relative_to(ensemble_hpo_output_dir)), # Corrected path
"fold_models": [],
}

# Populate fold_models with paths to saved artifacts
for artifact_detail in fold_artifact_details:
fold_def = {
"fold_id": artifact_detail.get("fold_id"), # Include fold ID
"model_path": None,
"target_scaler_path": None,
"input_size_path": None,
"config_path": None,
}
# Populate fold_models with paths relative to best_trial_artifacts_dir
for artifact_detail in fold_artifact_details:
fold_def = {
"fold_id": artifact_detail.get("fold_id"),
"model_path": None,
"target_scaler_path": None,
"data_scaler_path": None, # Added placeholder
"input_size_path": None,
"config_path": None,
}

# Process each path, making it relative if possible
for key in ["model_path", "target_scaler_path", "input_size_path", "config_path"]:
abs_path_str = artifact_detail.get(key)
if abs_path_str:
abs_path = Path(abs_path_str)
try:
# Make path relative to the trial artifacts dir
relative_path = str(abs_path.relative_to(best_trial_artifacts_dir))
fold_def[key] = relative_path
except ValueError:
logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
fold_def[key] = str(abs_path) # Fallback to absolute path
# Process each path, making it relative if possible
# Added "data_scaler_path" to the list of keys to process
for key in ["model_path", "target_scaler_path", "data_scaler_path", "input_size_path", "config_path"]:
abs_path_str = artifact_detail.get(key)
if abs_path_str:
abs_path = Path(abs_path_str).absolute()
try:
# Make path relative to the trial artifacts dir (where models/scalers reside)
relative_path = str(abs_path.relative_to(best_trial_artifacts_dir.absolute()))
fold_def[key] = relative_path
except ValueError:
# This shouldn't happen if paths were saved correctly, but handle just in case
logger.warning(f"Failed to make path {abs_path} relative to {best_trial_artifacts_dir}. Saving absolute path for {key}.")
fold_def[key] = str(abs_path) # Fallback to absolute path

best_ensemble_definition["fold_models"].append(fold_def)
best_ensemble_definition["fold_models"].append(fold_def)


try:
with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
json.dump(best_ensemble_definition, f, indent=4)
logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
except Exception as e:
logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)
try:
with open(ensemble_definition_file, 'w', encoding='utf-8') as f:
json.dump(best_ensemble_definition, f, indent=4)
logger.info(f"Best ensemble definition saved to {ensemble_definition_file}")
except Exception as e:
logger.error(f"Failed to save best ensemble definition: {e}", exc_info=True)


# --- Optional: Clean up artifact directories for non-best trials ---
if not args.keep_artifacts:
logger.info("Cleaning up artifact directories for non-best trials...")
ensemble_artifacts_base_dir = hpo_base_output_dir / "ensemble_runs_artifacts"
# The base path for all trial artifacts within the ensemble dir
ensemble_artifacts_base_dir = ensemble_hpo_output_dir / "ensemble_runs_artifacts" # Corrected base path
if ensemble_artifacts_base_dir.exists():
for item in ensemble_artifacts_base_dir.iterdir():
if item.is_dir():

rules.md
@ -1,68 +0,0 @@
## Coding Style Rules & Paradigms

### Configuration Driven

* Uses **Pydantic** heavily (`utils/models.py`) to define configuration schemas.
* Configuration is loaded from a **YAML file** (`config.yaml`) at runtime (`main.py`), as sketched below.
* The `Config` object (or relevant sub-configs) is **passed down** through function calls, making parameters explicit.
* A **template configuration** (`_config.yaml`) is often included within the package.
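A minimal sketch of this pattern, with invented model and field names (the real schemas live in `utils/models.py`):

```python
from pathlib import Path

import yaml
from pydantic import BaseModel

class ProcessingConfig(BaseModel):
    threshold: float = 0.5        # illustrative field

class Config(BaseModel):
    debug: bool = False
    processing: ProcessingConfig

def load_config(path: Path) -> Config:
    """Load and validate the YAML config; validation errors are treated as critical."""
    with open(path, "r", encoding="utf-8") as f:
        return Config(**yaml.safe_load(f))

def run_stage(cfg: ProcessingConfig) -> None:
    """Sub-configs are passed down explicitly rather than read from globals."""
    ...
```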

### Modularity

* Code is organized into **logical sub-packages** (`io`, `processing`, `pipeline`, `visualization`, `synthesis`, `utils`, `validation`).
* Each sub-package has an `__init__.py`, often used to **expose key functions/classes** to the parent level (see the sketch after this list).
* **Helper functions** (often internal, prefixed with `_`) are frequently used to break down complex logic within modules (e.g., `processing/surface_helper.py`, `pipeline/runner.py` helpers).
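For example, a hypothetical `processing/__init__.py` that re-exports the sub-package's public API while leaving underscore-prefixed helpers internal:

```python
# processing/__init__.py -- module and function names are illustrative
from .segmentation import segment_volume
from .surface import compute_surface_mesh  # internally builds on _helpers in surface_helper.py

__all__ = ["segment_volume", "compute_surface_mesh"]
```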

### Logging

* Uses the standard **`logging` library**.
* Loggers are obtained per module using `logger = logging.getLogger(__name__)`.
* **Logging levels** (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`) are used semantically:
    * `DEBUG`: Verbose internal steps.
    * `INFO`: Major milestones/stages.
    * `WARNING`: Recoverable issues or deviations.
    * `ERROR`: Specific failures that might be handled.
    * `CRITICAL`: Fatal errors causing exits.
* **Root logger configuration** happens in `main.py`, potentially adjusted based on the `debug` flag in the config (see the sketch below).
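A short sketch of the convention (function and variable names invented):

```python
import logging

logger = logging.getLogger(__name__)  # one logger per module

def run_segmentation(volume):
    logger.info("Starting segmentation stage")                        # major milestone
    logger.debug("Input volume shape: %s", volume.shape)              # verbose internals
    if volume.ndim != 3:
        logger.warning("Expected a 3D volume, got %dD", volume.ndim)  # recoverable deviation

# In main.py the root logger is configured once, honoring the debug flag:
# logging.basicConfig(level=logging.DEBUG if config.debug else logging.INFO)
```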

### Error Handling ("Fail Hard but Helpful")

* The main entry point (`main.py`) uses a **top-level `try...except` block** to catch major failures during config loading or pipeline execution.
* **Critical errors** are logged with tracebacks (`exc_info=True`) and result in `sys.exit(1)`.
* Functions often return a **tuple indicating success/failure** and results/error messages (e.g., `(result_data, error_message)` or `(success_flag, result_data)`), as sketched after this list.
* Lower-level functions may log errors/warnings but **allow processing to continue** if feasible and configured (e.g., `allow_segmentation_errors`).
* **Specific exceptions** are caught where appropriate (`FileNotFoundError`, `pydicom.errors.InvalidDicomError`, `ValueError`, etc.).
* **Pydantic validation errors** during config loading are treated as critical.
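A minimal sketch of the `(result_data, error_message)` convention (loader name and file format invented):

```python
from typing import Optional, Tuple

import numpy as np

def load_volume(path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """Return (data, None) on success or (None, error_message) on failure."""
    try:
        data = np.load(path)
    except FileNotFoundError:
        return None, f"Volume file not found: {path}"
    except ValueError as e:
        return None, f"Could not parse volume file {path}: {e}"
    return data, None

volume, err = load_volume("scan.npy")
if err:
    # The caller decides: log a warning and continue, or treat as critical and exit.
    ...
```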

### Typing

* Consistent use of **Python type hints** (`typing` module: `Optional`, `Dict`, `List`, `Tuple`, `Union`, `Callable`, `Literal`, etc.).
* **Pydantic models** rely heavily on type hints for validation.

### Data Structures

* **Pydantic models** define primary configuration and result structures (e.g., `Config`, `ProcessingResult`, `CombinedDicomDataset`).
* **NumPy arrays** are fundamental for image/volume data.
* **Pandas DataFrames** are used for aggregating results, metadata, and creating reports (Excel).
* Standard **Python dictionaries** are used extensively for metadata and intermediate data passing.

### Naming Conventions

* Follows **PEP 8**: `snake_case` for variables and functions, `PascalCase` for classes.
* Internal helper functions are typically prefixed with an **underscore (`_`)**.
* Constants are defined in **`UPPER_SNAKE_CASE`** (often in a dedicated `utils/constants.py`).

### Documentation

* **Docstrings** are present for most functions and classes, explaining purpose, arguments (`Args:`), and return values (`Returns:`).
* Minimal **inline comments**; code aims to be self-explanatory, with docstrings providing higher-level context. (Matches your custom instructions.)

### Dependencies

* Managed via `requirements.txt`.
* Uses the standard **scientific Python stack** (`numpy`, `pandas`, `scipy`, `scikit-image`, `matplotlib`), **domain-specific libraries** (`pydicom`), **utility libraries** (`PyYAML`, `joblib`, `tqdm`, `openpyxl`), and `pydantic` for configuration/validation.

### Parallelism

* Uses **`joblib`** for parallel processing, configurable via the main config (`mainprocess_core_count`, `subprocess_core_count`), as sketched below.
* Parallelism can be **disabled** via configuration or debug mode.
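A hedged sketch of config-driven joblib parallelism; the two core-count fields come from the bullet above, while `process_one` and the `debug` check are illustrative:

```python
from joblib import Parallel, delayed

def process_one(item):
    ...  # hypothetical per-item work

def process_all(items, config):
    # Parallelism is disabled in debug mode or when the config requests a single core.
    n_jobs = 1 if config.debug else config.mainprocess_core_count
    return Parallel(n_jobs=n_jobs)(delayed(process_one)(item) for item in items)
```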