import pandas as pd import numpy as np import yaml import logging import matplotlib.pyplot as plt import seaborn as sns from pathlib import Path # Import Forecasting Providers from forecasting_model.data_processing import load_raw_data from optimizer.forecasting.base import ForecastProvider from optimizer.forecasting.single_model import SingleModelProvider from optimizer.forecasting.ensemble import EnsembleProvider from optimizer.optimization.battery import solve_battery_optimization_hourly from optimizer.utils.optim_config import OptimizationRunConfig from forecasting_model.utils.forecast_config_model import DataConfig, MainConfig # Import the newly created loading functions from optimizer.utils.model_io import load_single_model_artifact, load_ensemble_artifact from typing import Dict, Any, Optional, Union # Added Union # Silence overly verbose libraries if needed mpl_logger = logging.getLogger('matplotlib') mpl_logger.setLevel(logging.WARNING) pil_logger = logging.getLogger('PIL') pil_logger.setLevel(logging.WARNING) # --- Basic Logging Setup --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)-7s - %(message)s', datefmt='%H:%M:%S') logger = logging.getLogger() def load_optimization_config(config_path: str) -> OptimizationRunConfig | None: """Loads the main optimization configuration from a YAML file.""" logger.info(f"Loading optimization config from {config_path}") try: with open(config_path, 'r') as f: config_data = yaml.safe_load(f) return OptimizationRunConfig(**config_data) except FileNotFoundError: logger.error(f"Optimization config file not found at {config_path}") return None except yaml.YAMLError as e: logger.error(f"Error parsing YAML optimization config file: {e}") return None except Exception as e: logger.error(f"Error loading optimization config: {e}", exc_info=True) return None def load_main_forecasting_config(config_path: str) -> MainConfig | None: """Loads the main forecasting configuration from a YAML file.""" logger.info(f"Loading main forecasting config from: {config_path}") try: with open(config_path, 'r') as f: config_data = yaml.safe_load(f) # Assuming MainConfig is the top-level model in forecast_config_model.py return MainConfig(**config_data) except FileNotFoundError: logger.error(f"Main forecasting config file not found at {config_path}") return None except yaml.YAMLError as e: logger.error(f"Error parsing YAML main forecasting config file: {e}") return None except Exception as e: logger.error(f"Error loading main forecasting config: {e}", exc_info=True) return None # --- Main Execution Logic --- # 1. Load configs # 2. Initialize forecast providers # 3. For each time window: # a. Get forecasts for all horizons # b. Run optimization for each horizon # c. Store results # 4. Evaluate and visualize if __name__ == "__main__": logger.info("Starting battery optimization evaluation with baseline and models/ensembles.") # --- Load Main Optimization Config --- optimization_config_path = "optim_config.yaml" optimization_config = load_optimization_config(optimization_config_path) if optimization_config is None: logger.critical("Failed to load main optimization config. Exiting.") # Use critical for exit exit(1) # Use non-zero exit code for error optim_run_script_dir = Path(__file__).parent if not optimization_config.models: logger.critical("No models or ensembles specified in optimization config. Exiting.") exit(1) # Try to load the main forecasting config for the first model/ensemble to get the data path first_model_config_path = Path(optimization_config.models[0].model_config_path) main_forecasting_config_for_data = load_main_forecasting_config(str(first_model_config_path)) if main_forecasting_config_for_data is None: logger.critical("Failed to load forecasting config for the first specified model/ensemble to get data path. Exiting.") exit(1) # Use the DataConfig from the first loaded forecasting config historical_data_config = DataConfig( data_path=main_forecasting_config_for_data.data.data_path, raw_datetime_col=main_forecasting_config_for_data.data.raw_datetime_col, raw_datetime_format=main_forecasting_config_for_data.data.raw_datetime_format, datetime_col=main_forecasting_config_for_data.data.datetime_col, raw_target_col=main_forecasting_config_for_data.data.raw_target_col, target_col=main_forecasting_config_for_data.data.target_col, expected_frequency=main_forecasting_config_for_data.data.expected_frequency, fill_initial_target_nans=main_forecasting_config_for_data.data.fill_initial_target_nans ) logger.info(f"Loading original historical data from: {historical_data_config.data_path}") try: full_historical_df = load_raw_data(historical_data_config) if full_historical_df.empty: logger.critical("Loaded original historical data is empty. Cannot proceed. Exiting.") exit(1) # Ensure data is at the expected frequency and sorted full_historical_df = full_historical_df.sort_index().asfreq(historical_data_config.expected_frequency) # Fill any NaNs introduced by asfreq if not already handled by fill_initial_target_nans if full_historical_df[historical_data_config.target_col].isnull().any(): logger.warning(f"NaNs found after setting frequency {historical_data_config.expected_frequency}. Applying ffill().bfill().") full_historical_df[historical_data_config.target_col] = full_historical_df[historical_data_config.target_col].ffill().bfill() if full_historical_df[historical_data_config.target_col].isnull().any(): logger.critical("NaNs still remain after filling. Cannot proceed. Exiting.") exit(1) logger.info(f"Original historical data loaded and prepared. Shape: {full_historical_df.shape}") except Exception as e: logger.critical(f"Failed to load or prepare original historical data from {historical_data_config.data_path}: {e}", exc_info=True) exit(1) # --- Define Evaluation Window and Step --- optimization_horizon_hours = optimization_config.optimization_horizon_hours step_size_hours = 1 # Evaluate every hour by sliding the window by 1 hour logger.info(f"Using optimization horizon: {optimization_horizon_hours} hours with a step size of {step_size_hours} hour(s).") # --- Storage for results per time window --- window_results_list = [] # --- Load Models/Ensembles and Instantiate Providers --- # Store loaded provider instances, keyed by the name from optim_config forecast_providers: Dict[str, ForecastProvider] = {} # Store provider instances for model_eval_config in optimization_config.models: provider_name = model_eval_config.name artifact_type = model_eval_config.type artifact_path = Path(model_eval_config.model_path) # Path to .ckpt or .json config_path = Path(model_eval_config.model_config_path) # Path to YAML config provider_instance: Optional[ForecastProvider] = None # Initialize provider instance if artifact_type == 'model': logger.info(f"Attempting to load single model artifact and create provider: {provider_name}") target_scaler_path = Path(model_eval_config.target_scaler_path) if model_eval_config.target_scaler_path else None input_size_path = artifact_path.parent / "input_size.pt" # Derive path convention if not input_size_path.exists() and artifact_path.parent.name == 'checkpoints': input_size_path = artifact_path.parent.parent / "input_size.pt" loaded_artifact_info = load_single_model_artifact( model_path=artifact_path, config_path=config_path, input_size_path=input_size_path, target_scaler_path=target_scaler_path ) if loaded_artifact_info: try: provider_instance = SingleModelProvider( model_instance=loaded_artifact_info['model_instance'], feature_config=loaded_artifact_info['feature_config'], target_col=loaded_artifact_info['main_forecasting_config'].data.target_col, # Get target col from loaded config target_scaler=loaded_artifact_info['target_scaler'] ) # Validation check (basic horizon check) if 1 not in provider_instance.feature_config.forecast_horizon: logger.error(f"Model '{provider_name}' forecast horizon {provider_instance.feature_config.forecast_horizon} does not include 1 hour. Cannot use for this evaluation.") provider_instance = None # Discard if validation fails else: logger.info(f"Successfully created SingleModelProvider for '{provider_name}'.") except Exception as e: logger.error(f"Failed to instantiate SingleModelProvider for '{provider_name}': {e}", exc_info=True) else: logger.warning(f"Single model artifact '{provider_name}' could not be loaded. Skipping provider creation.") elif artifact_type == 'ensemble': logger.info(f"Attempting to load ensemble artifact and create provider: {provider_name}") hpo_base_output_dir_for_ensemble = artifact_path.parent loaded_artifact_info = load_ensemble_artifact( ensemble_definition_path=artifact_path, hpo_base_output_dir=hpo_base_output_dir_for_ensemble ) if loaded_artifact_info: try: # Ensure necessary keys are present before instantiation required_keys = ['fold_artifacts', 'ensemble_method', 'ensemble_feature_config', 'ensemble_target_col'] if not all(key in loaded_artifact_info for key in required_keys): missing_keys = [key for key in required_keys if key not in loaded_artifact_info] raise ValueError(f"Ensemble artifact info is missing required keys: {missing_keys}") provider_instance = EnsembleProvider( fold_artifacts=loaded_artifact_info['fold_artifacts'], ensemble_method=loaded_artifact_info['ensemble_method'], ensemble_feature_config=loaded_artifact_info['ensemble_feature_config'], ensemble_target_col=loaded_artifact_info['ensemble_target_col'] ) # Validation check (basic horizon check) if 1 not in provider_instance.common_forecast_horizons: logger.error(f"Ensemble '{provider_name}' common forecast horizon {provider_instance.common_forecast_horizons} does not include 1 hour. Cannot use for this evaluation.") provider_instance = None # Discard if validation fails else: logger.info(f"Successfully created EnsembleProvider for '{provider_name}'.") except Exception as e: logger.error(f"Failed to instantiate EnsembleProvider for '{provider_name}': {e}", exc_info=True) else: logger.warning(f"Ensemble artifact '{provider_name}' could not be loaded. Skipping provider creation.") else: logger.error(f"Unknown artifact type '{artifact_type}' for '{provider_name}'. Skipping.") continue # Skip to next model_eval_config # Store the successfully created provider instance if provider_instance: forecast_providers[provider_name] = provider_instance # --- End Loading --- if not forecast_providers: logger.critical("No forecast providers were successfully created. Cannot proceed with evaluation. Exiting.") exit(1) # --- Calculate Max Lookback Needed Across All Providers --- max_required_lookback = 0 for provider_name, provider in forecast_providers.items(): try: lookback = provider.get_required_lookback() max_required_lookback = max(max_required_lookback, lookback) logger.debug(f"Provider '{provider_name}' requires lookback: {lookback}") except AttributeError: logger.error(f"Provider '{provider_name}' does not have a 'get_required_lookback' method. Cannot determine lookback requirements accurately. Exiting.") exit(1) except Exception as e: logger.error(f"Error getting lookback for provider '{provider_name}': {e}. Exiting.", exc_info=True) exit(1) logger.info(f"Maximum lookback required across all providers: {max_required_lookback} hours.") # The first timestamp for which we can generate a forecast needs `max_required_lookback` points *before* it. # If optimization starts at `window_start_time` (iloc `i`), the forecast generation needs data up to `i-1`. # The historical slice passed to `get_forecast` must contain `max_required_lookback` points, ending at `i-1`. # Therefore, the slice starts at `i - max_required_lookback`. This must be >= 0. # So, `i >= max_required_lookback`. first_window_start_iloc = max_required_lookback # The last window starts such that the window ends within the data: `i + optimization_horizon_hours - 1 < len(df)` # So, `i < len(df) - optimization_horizon_hours + 1`. last_window_start_iloc = len(full_historical_df) - optimization_horizon_hours if first_window_start_iloc > last_window_start_iloc: logger.critical(f"Not enough historical data ({len(full_historical_df)} hours) for the required lookback ({max_required_lookback}) and optimization horizon ({optimization_horizon_hours}). First possible window start iloc: {first_window_start_iloc}, last possible: {last_window_start_iloc}. Exiting.") exit(1) logger.info(f"Evaluating over historical windows from iloc {first_window_start_iloc} to {last_window_start_iloc}.") # --- Evaluation Loop --- for i in range(first_window_start_iloc, last_window_start_iloc + 1, step_size_hours): # Define the actual optimization window in terms of iloc and time window_start_iloc = i window_end_iloc = i + optimization_horizon_hours - 1 # Inclusive index for the window end # Check if the window is complete within the dataset bounds if window_end_iloc >= len(full_historical_df): logger.warning(f"Skipping window starting at iloc {window_start_iloc}: extends beyond available data (needs up to iloc {window_end_iloc}, max is {len(full_historical_df)-1}).") continue window_timestamps = full_historical_df.index[window_start_iloc : window_end_iloc + 1] # Double-check length just in case if len(window_timestamps) != optimization_horizon_hours: logger.warning(f"Skipping window starting at iloc {window_start_iloc} due to unexpected timestamp slice length ({len(window_timestamps)} instead of {optimization_horizon_hours} hours).") continue window_start_time = window_timestamps[0] window_end_time = window_timestamps[-1] logger.info(f"Processing window: {window_start_time.strftime('%Y-%m-%d %H:%M')} to {window_end_time.strftime('%Y-%m-%d %H:%M')} (iloc {window_start_iloc})") # --- Prepare Historical Slice for Forecasting --- # We need data *up to* the beginning of the optimization window, including lookback. # Slice should end at iloc `window_start_iloc - 1`. # Slice should start at `window_start_iloc - max_required_lookback`. hist_slice_start_iloc = max(0, window_start_iloc - max_required_lookback) hist_slice_end_iloc = window_start_iloc # Exclusive end iloc for slicing, so it includes up to window_start_iloc - 1 if hist_slice_end_iloc <= hist_slice_start_iloc: logger.error(f"Invalid historical slice range for window starting at {window_start_time}: start_iloc={hist_slice_start_iloc}, end_iloc={hist_slice_end_iloc}. Skipping window.") continue historical_slice_for_forecasting = full_historical_df.iloc[hist_slice_start_iloc : hist_slice_end_iloc].copy() # Check if the slice has the expected length (at least max_required_lookback, unless near start of data) if len(historical_slice_for_forecasting) < max_required_lookback and window_start_iloc >= max_required_lookback: logger.warning(f"Historical slice for window starting {window_start_time} has unexpected length {len(historical_slice_for_forecasting)}, expected {max_required_lookback}. Check slicing logic. Skipping.") continue elif len(historical_slice_for_forecasting) == 0: logger.warning(f"Historical slice for window starting {window_start_time} is empty. Skipping.") continue logger.debug(f"Using historical slice from {historical_slice_for_forecasting.index.min()} to {historical_slice_for_forecasting.index.max()} (Length: {len(historical_slice_for_forecasting)}) for forecasting.") # --- Collect Window Results --- window_results = { 'start_time': window_start_time, 'end_time': window_end_time, 'actual_prices': full_historical_df[historical_data_config.target_col].iloc[window_start_iloc : window_end_iloc + 1].tolist() } # --- Baseline Optimization --- baseline_prices_input = np.array(window_results['actual_prices']) logger.debug(f"Running baseline optimization for window starting {window_start_time}") try: baseline_status, baseline_profit, baseline_power, baseline_B = solve_battery_optimization_hourly( baseline_prices_input, optimization_config.initial_b, optimization_config.max_capacity, optimization_config.max_rate ) window_results['baseline'] = { "status": baseline_status, "profit": baseline_profit, "power_schedule": baseline_power.tolist() if baseline_power is not None else None, "B_schedule": baseline_B.tolist() if baseline_B is not None else None } logger.debug(f"Baseline profit: {baseline_profit if baseline_profit is not None else 'N/A'}") except Exception as e: logger.error(f"Baseline optimization failed for window starting {window_start_time}: {e}", exc_info=True) window_results['baseline'] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None} # --- Forecast Provider Optimizations --- for provider_name, provider_instance in forecast_providers.items(): logger.debug(f"Generating forecast and running optimization for provider '{provider_name}' for window starting {window_start_time}") # Generate forecast using the provider's get_forecast method try: forecast_prices_input = provider_instance.get_forecast( historical_data_slice=historical_slice_for_forecasting.copy(), # Pass a copy optimization_horizon_hours=optimization_horizon_hours ) except Exception as e: logger.error(f"Error calling get_forecast for provider '{provider_name}': {e}", exc_info=True) forecast_prices_input = None if forecast_prices_input is None or len(forecast_prices_input) != optimization_horizon_hours: logger.warning(f"Forecast generation failed or returned incorrect length ({len(forecast_prices_input) if forecast_prices_input is not None else 0} instead of {optimization_horizon_hours}) for provider '{provider_name}' window starting {window_start_time}. Skipping optimization.") window_results[provider_name] = {"status": "Forecast Generation Failed", "profit": None, "power_schedule": None, "B_schedule": None} continue # Skip optimization for this provider/window # Ensure the forecast input is a numpy array of the correct shape if not isinstance(forecast_prices_input, np.ndarray) or forecast_prices_input.shape != (optimization_horizon_hours,): logger.error(f"Forecast input for provider '{provider_name}' has incorrect format ({type(forecast_prices_input)}, shape {forecast_prices_input.shape if isinstance(forecast_prices_input, np.ndarray) else 'N/A'}). Expected ({optimization_horizon_hours},). Skipping optimization.") window_results[provider_name] = {"status": "Invalid Forecast Format", "profit": None, "power_schedule": None, "B_schedule": None} continue # --- Run Optimization with Forecast Prices --- try: model_status, model_profit, model_power, model_B = solve_battery_optimization_hourly( forecast_prices_input, optimization_config.initial_b, optimization_config.max_capacity, optimization_config.max_rate ) window_results[provider_name] = { "status": model_status, "profit": model_profit, "power_schedule": model_power.tolist() if model_power is not None else None, "B_schedule": model_B.tolist() if model_B is not None else None } logger.debug(f"Provider '{provider_name}' profit: {model_profit if model_profit is not None else 'N/A'}") except Exception as e: logger.error(f"Optimization failed for provider '{provider_name}' window starting {window_start_time}: {e}", exc_info=True) window_results[provider_name] = {"status": "Error", "profit": None, "power_schedule": None, "B_schedule": None} # Append results for this window window_results_list.append(window_results) logger.debug(f"Finished processing window starting at: {window_start_time.strftime('%Y-%m-%d %H:%M')}") logger.info("Finished processing all evaluation windows.") # --- Post-processing and Plotting --- logger.info("Starting results analysis and plotting.") if not window_results_list: logger.warning("No window results were collected. Skipping plotting.") exit(0) # Not necessarily an error state # Convert results list to a DataFrame flat_results = [] successfully_loaded_provider_names = list(forecast_providers.keys()) # Names of providers used for window_res in window_results_list: base_info = { 'start_time': window_res['start_time'], 'end_time': window_res['end_time'], } # Add baseline results flat_results.append({**base_info, 'type': 'baseline', **window_res.get('baseline', {})}) # Add provider results for provider_name in successfully_loaded_provider_names: provider_res = window_res.get(provider_name, {}) # Get results or empty dict flat_results.append({**base_info, 'type': provider_name, **provider_res}) results_df = pd.DataFrame(flat_results) results_df['start_time'] = pd.to_datetime(results_df['start_time']) # Ensure datetime type # Filter out rows where essential optimization results are missing # results_df.dropna(subset=['profit', 'power_schedule'], inplace=True) # Be careful with dropna # Calculate Profit Absolute Error over time profit_pivot = results_df.pivot_table(index='start_time', columns='type', values='profit') mae_df = pd.DataFrame(index=profit_pivot.index) if 'baseline' in profit_pivot.columns: for provider_name in successfully_loaded_provider_names: if provider_name in profit_pivot.columns: # Use .sub() and .abs() to handle potential NaNs gracefully mae_df[f'Profit_Abs_Error_{provider_name}'] = profit_pivot[provider_name].sub(profit_pivot['baseline']).abs() else: logger.warning(f"Cannot calculate profit MAE for provider '{provider_name}'. Data not found in pivoted results.") else: logger.warning("Cannot calculate profit MAE because baseline results are missing or incomplete.") # --- Plotting --- # Plot 1: Price and First Hour's Power Schedule Over Time logger.info("Generating Price and Power Schedule plot.") continuous_power_data = [] for window_res in window_results_list: start_time = window_res['start_time'] # Baseline power baseline_data = window_res.get('baseline', {}) if baseline_data.get('power_schedule') and len(baseline_data['power_schedule']) > 0: continuous_power_data.append({'time': start_time, 'type': 'baseline', 'power': baseline_data['power_schedule'][0]}) # Provider powers for provider_name in successfully_loaded_provider_names: provider_data = window_res.get(provider_name, {}) if provider_data.get('power_schedule') and len(provider_data['power_schedule']) > 0: continuous_power_data.append({'time': start_time, 'type': provider_name, 'power': provider_data['power_schedule'][0]}) continuous_power_df = pd.DataFrame(continuous_power_data) if not continuous_power_df.empty: continuous_power_df['time'] = pd.to_datetime(continuous_power_df['time']) # Get historical prices corresponding to the evaluation window start times eval_start_times = results_df['start_time'].unique() price_plot_df = full_historical_df.loc[eval_start_times, [historical_data_config.target_col]].reset_index() price_plot_df.rename(columns={price_plot_df.columns[0]: 'time', historical_data_config.target_col: 'price'}, inplace=True) # Use positional index for timestamp column rename plot_range_start = continuous_power_df['time'].min() plot_range_end = continuous_power_df['time'].max() # Filter data for the plot range filtered_price_df = price_plot_df[(price_plot_df['time'] >= plot_range_start) & (price_plot_df['time'] <= plot_range_end)] filtered_power_df = continuous_power_df[(continuous_power_df['time'] >= plot_range_start) & (continuous_power_df['time'] <= plot_range_end)] if not filtered_power_df.empty: fig1, ax1 = plt.subplots(figsize=(15, 7)) ax2 = ax1.twinx() sns.lineplot(data=filtered_price_df, x='time', y='price', ax=ax1, color='gray', linestyle='--', label='Historical Price (Window Start)', zorder=1) ax1.set_ylabel('Price (€/MWh)', color='gray') ax1.tick_params(axis='y', labelcolor='gray') sns.lineplot(data=filtered_power_df, x='time', y='power', hue='type', ax=ax2, zorder=2) ax2.set_ylabel('Power (MW)') h1, l1 = ax1.get_legend_handles_labels() h2, l2 = ax2.get_legend_handles_labels() ax2.legend(h1 + h2, l1 + l2, loc='upper left', title='Schedule Type') ax1.get_legend().remove() # Remove the original legend from ax1 ax1.set_xlabel('Time') ax1.set_title('Battery Power Schedule (1st Hour) vs. Historical Price (Window Start)') plt.tight_layout() plt.savefig("power_schedule_vs_price.png") logger.info("Price and Power Schedule plot saved as power_schedule_vs_price.png") # plt.show() else: logger.warning("No power data available within the determined plot range.") else: logger.warning("No continuous power data generated for plotting power schedule.") # Plot 2: Absolute Profit Error over time logger.info("Generating Profit Absolute Error plot.") if not mae_df.empty and not mae_df.isnull().all().all(): # Check if not empty and not all NaN fig2, ax = plt.subplots(figsize=(15, 7)) # Use the plot range from the power plot if available mae_plot_range_start = plot_range_start if 'plot_range_start' in locals() else mae_df.index.min() mae_plot_range_end = plot_range_end if 'plot_range_end' in locals() else mae_df.index.max() filtered_mae_df = mae_df[(mae_df.index >= mae_plot_range_start) & (mae_df.index <= mae_plot_range_end)].copy() # Create copy # Optional: Handle or remove columns that are all NaN within the range filtered_mae_df.dropna(axis=1, how='all', inplace=True) if not filtered_mae_df.empty: sns.lineplot(data=filtered_mae_df, ax=ax) ax.set_xlabel('Time') ax.set_ylabel('Absolute Profit Error vs. Baseline (€)') ax.set_title('Absolute Profit Error of Providers vs. Baseline over Time') ax.legend(title='Provider Type') plt.tight_layout() plt.savefig("profit_abs_error_over_time.png") logger.info("Profit Absolute Error plot saved as profit_abs_error_over_time.png") # plt.show() else: logger.warning("MAE data is all NaN or empty within the plot range. Skipping MAE plot.") else: logger.warning("No valid data available to plot Profit Absolute Error.") logger.info("Evaluation and plotting completed.")