intermediate backup
20
forecasting_model/train/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""
Training sub-package: classic train/val/test runs, cross-validated ensemble
evaluation, and the LSTM Lightning module used by both.
"""

__version__ = "0.1.0"

# Expose core components for easier import
from .ensemble_evaluation import (
    run_ensemble_evaluation
)


# Expose main configuration class from utils
from ..utils import MainConfig

# Define __all__ for explicit public API (optional but good practice)
__all__ = [
    "run_ensemble_evaluation",
    "MainConfig",
]
276
forecasting_model/train/classic.py
Normal file
@@ -0,0 +1,276 @@
"""
Classic training routine: Train on initial data segment, validate and test on final segments.
"""

import logging
import time
from pathlib import Path
import pandas as pd
import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from typing import Dict, Optional

from forecasting_model.utils.forecast_config_model import MainConfig
from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.evaluation import evaluate_fold_predictions

from forecasting_model.utils.helper import save_results
from forecasting_model.io.plotting import plot_loss_curve_from_csv

logger = logging.getLogger(__name__)


def run_classic_training(
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path
) -> Optional[Dict[str, float]]:
    """
    Runs a single training pipeline using a classic train/val/test split.

    Args:
        config: The main configuration object.
        full_df: The complete raw DataFrame.
        output_base_dir: The base directory where general outputs are saved.
            Classic results will be saved in a subdirectory.

    Returns:
        A dictionary containing test metrics (e.g., {'MAE': ..., 'RMSE': ...})
        for the classic run, or None if it fails.
    """
    run_start_time = time.perf_counter()
    logger.info("--- Starting Classic Training Run ---")

    # Define a specific output directory for this run
    classic_output_dir = output_base_dir / "classic_run"
    classic_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Classic run outputs will be saved to: {classic_output_dir}")

    test_metrics: Optional[Dict[str, float]] = None
    best_val_score: Optional[float] = None
    best_model_path: Optional[str] = None

    try:
        # --- Data Splitting ---
        logger.info("Splitting data into classic train/val/test sets...")
        n_samples = len(full_df)
        val_frac = config.cross_validation.val_size_fraction
        test_frac = config.cross_validation.test_size_fraction
        train_idx, val_idx, test_idx = split_data_classic(n_samples, val_frac, test_frac)

        # Store test datetime index for evaluation plotting
        test_datetime_index = full_df.iloc[test_idx].index

        # --- Data Preparation ---
        logger.info("Preparing data loaders for the classic split...")
        train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
            full_df=full_df,
            train_idx=train_idx,
            val_idx=val_idx,
            test_idx=test_idx,
            target_col=config.data.target_col,
            feature_config=config.features,
            train_config=config.training,
            eval_config=config.evaluation
        )
        logger.info(f"Data loaders prepared. Input size determined: {input_size}")

        # Save artifacts specific to this run if needed (e.g., for later inference)
        torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
        torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
        torch.save(input_size, classic_output_dir / "classic_input_size.pt")
        # Save config for this run
        try:
            config_dump = config.model_dump()  # Pydantic v2
        except AttributeError:
            config_dump = config.dict()  # Fallback for Pydantic v1
        with open(classic_output_dir / "config.yaml", 'w') as f:
            yaml.dump(config_dump, f, default_flow_style=False)

        # --- Model Initialization ---
        model = LSTMForecastLightningModule(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler
        )
        logger.info("Classic LSTMForecastLightningModule initialized.")

        # --- PyTorch Lightning Callbacks ---
        monitor_metric = "val_MeanAbsoluteError"  # Monitor same metric as CV folds
        monitor_mode = "min"

        early_stop_callback = None
        if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
            early_stop_callback = EarlyStopping(
                monitor=monitor_metric, min_delta=0.0001,
                patience=config.training.early_stopping_patience, verbose=True, mode=monitor_mode
            )
            logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

        checkpoint_callback = ModelCheckpoint(
            dirpath=classic_output_dir / "checkpoints",
            filename="best_classic_model",  # Simple filename
            save_top_k=1, monitor=monitor_metric, mode=monitor_mode, verbose=True
        )
        logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        callbacks = [checkpoint_callback, lr_monitor]
        if early_stop_callback: callbacks.append(early_stop_callback)

        # --- PyTorch Lightning Logger ---
        pl_logger = CSVLogger(save_dir=str(classic_output_dir), name="training_logs")
        logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")

        # --- PyTorch Lightning Trainer ---
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        devices = 1 if accelerator == 'gpu' else None
        precision = getattr(config.training, 'precision', 32)

        trainer = pl.Trainer(
            accelerator=accelerator, devices=devices,
            max_epochs=config.training.epochs,
            callbacks=callbacks, logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader)//10),
            enable_progress_bar=True,
            gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
            precision=precision,
        )
        logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

        # --- Training ---
        logger.info("Starting classic model training...")
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        logger.info("Classic model training finished.")

        # Store best validation score and path
        best_val_score_tensor = trainer.checkpoint_callback.best_model_score
        best_model_path = trainer.checkpoint_callback.best_model_path
        best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None

        if best_val_score is not None:
            logger.info(f"Best validation score ({monitor_metric}): {best_val_score:.4f}")
            logger.info(f"Best model checkpoint path: {best_model_path}")
        else:
            logger.warning(f"Could not retrieve best validation score/path (metric: {monitor_metric}). Evaluation might use last model.")
            best_model_path = None

        # --- Prediction on Test Set ---
        logger.info("Starting prediction on classic test set using best checkpoint...")
        prediction_results_list = trainer.predict(
            ckpt_path=best_model_path if best_model_path else 'last',
            dataloaders=test_loader
        )

        # --- Evaluation ---
        if not prediction_results_list:
            logger.error("Predict phase did not return any results for classic run.")
            test_metrics = None
        else:
            try:
                # Shapes: (n_samples, len(horizons))
                all_preds_scaled = torch.cat([b['preds_scaled'] for b in prediction_results_list], dim=0).numpy()
                n_predictions = len(all_preds_scaled)  # Number of samples actually predicted

                if 'targets_scaled' in prediction_results_list[0]:
                    all_targets_scaled = torch.cat([b['targets_scaled'] for b in prediction_results_list], dim=0).numpy()
                    if len(all_targets_scaled) != n_predictions:
                        logger.error(f"Classic Run: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
                        raise ValueError("Prediction and target count mismatch during classic evaluation.")
                else:
                    raise ValueError("Targets missing from prediction results.")

                logger.info(f"Processing {n_predictions} prediction results for classic test set...")

                # --- Calculate Correct Time Index for Plotting (First Horizon) ---
                target_time_index_for_plotting = None
                if test_idx is not None and config.features.forecast_horizon:
                    try:
                        test_block_index = full_df.index[test_idx]  # Use the test_idx from classic split
                        seq_len = config.features.sequence_length
                        first_horizon = config.features.forecast_horizon[0]
                        start_offset = seq_len + first_horizon - 1
                        if start_offset < len(test_block_index):
                            end_index = min(start_offset + n_predictions, len(test_block_index))
                            target_time_index_for_plotting = test_block_index[start_offset:end_index]
                            if len(target_time_index_for_plotting) != n_predictions:
                                logger.warning(f"Classic Run: Calculated target time index length ({len(target_time_index_for_plotting)}) "
                                               f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
                                target_time_index_for_plotting = None
                        else:
                            logger.warning(f"Classic Run: Cannot calculate target time index, start offset ({start_offset}) "
                                           f"exceeds test block length ({len(test_block_index)}).")
                    except Exception as e:
                        logger.error(f"Classic Run: Error calculating target time index for plotting: {e}", exc_info=True)
                        target_time_index_for_plotting = None  # Ensure it's None if error occurs
                else:
                    logger.warning("Classic Run: Skipping target time index calculation (missing test_idx or forecast_horizon).")
                # --- End Index Calculation ---

                # Use the classic run specific objects and config
                test_metrics = evaluate_fold_predictions(
                    y_true_scaled=all_targets_scaled,
                    y_pred_scaled=all_preds_scaled,
                    target_scaler=target_scaler,
                    eval_config=config.evaluation,
                    fold_num=-1,  # Indicate classic run
                    output_dir=str(classic_output_dir),
                    plot_subdir="plots",
                    prediction_time_index=target_time_index_for_plotting,  # Pass the correctly calculated index
                    forecast_horizons=config.features.forecast_horizon,
                    plot_title_prefix="Classic Run"
                )
                # Save metrics
                save_results({"overall_metrics": test_metrics}, classic_output_dir / "test_metrics.json")
                logger.info(f"Classic run test metrics (overall): {test_metrics}")

                # --- Plot Loss Curve for Classic Run ---
                try:
                    # Adjusted logic to find metrics.csv inside potential version_*/ directories
                    classic_log_dir = classic_output_dir / "training_logs"
                    metrics_file = None
                    version_dirs = list(classic_log_dir.glob("version_*"))
                    if version_dirs:
                        # Assuming the latest version directory contains the relevant logs
                        latest_version_dir = max(version_dirs, key=lambda p: p.stat().st_mtime)
                        potential_metrics_file = latest_version_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file
                        else:
                            logger.warning(f"Classic Run: metrics.csv not found in latest version directory: {latest_version_dir}")
                    else:
                        # Fallback if no version_* directories exist (less common with CSVLogger)
                        potential_metrics_file = classic_log_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file

                    if metrics_file and metrics_file.is_file():
                        plot_loss_curve_from_csv(
                            metrics_csv_path=metrics_file,
                            output_path=classic_output_dir / "loss_curve.png",
                            title="Classic Run Training Progression",
                            train_loss_col='train_loss',  # Changed from 'train_loss_epoch'
                            val_loss_col='val_loss'  # Keep as 'val_loss'
                        )
                        logger.info(f"Generating loss curve for classic run from: {metrics_file}")
                    else:
                        logger.warning(f"Classic Run: Could not find metrics.csv in {classic_log_dir} or its version subdirectories for loss curve plot.")
                except Exception as plot_e:
                    logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
                # --- End Classic Loss Plotting ---

            except Exception as e:  # Covers KeyError/ValueError from result processing and anything unexpected
                logger.error(f"Error processing classic prediction results: {e}", exc_info=True)
                test_metrics = None

    except Exception as e:
        logger.error(f"An error occurred during the classic training pipeline: {e}", exc_info=True)
        test_metrics = None  # Indicate failure

    finally:
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        run_end_time = time.perf_counter()
        logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---")

    return test_metrics
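
# Minimal usage sketch (illustrative only): how run_classic_training might be invoked
# from a script. The config path, CSV path, and index column below are assumptions,
# not part of this module; the YAML is expected to validate into MainConfig.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with open("config.yaml", "r") as f:  # hypothetical config file
        cfg = MainConfig(**yaml.safe_load(f))

    # Hypothetical dataset with a datetime index and the configured target column.
    df = pd.read_csv("data.csv", index_col=0, parse_dates=True)

    metrics = run_classic_training(config=cfg, full_df=df, output_base_dir=Path("output"))
    print(f"Classic run test metrics: {metrics}")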
425
forecasting_model/train/ensemble_evaluation.py
Normal file
@@ -0,0 +1,425 @@
"""
Ensemble evaluation for time series forecasting models.

This module provides functionality to evaluate ensemble predictions
by combining predictions from n-1 folds and testing on the remaining fold.
"""

import logging
import numpy as np
import torch
import yaml  # For loading fold config
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd  # For time index handling
import pickle  # Need pickle for the specific error check

from forecasting_model.evaluation import evaluate_fold_predictions
from forecasting_model.train.model import LSTMForecastLightningModule
from forecasting_model.utils.forecast_config_model import MainConfig

logger = logging.getLogger(__name__)


def load_fold_model_and_objects(
    fold_dir: Path,
) -> Optional[Tuple[LSTMForecastLightningModule, MainConfig, torch.utils.data.DataLoader, Union[StandardScaler, MinMaxScaler, None], int, Optional[pd.Index], List[int]]]:
    """
    Load a trained model, its config, dataloader, scaler, input_size, prediction time index, and forecast horizons.

    Args:
        fold_dir: Directory containing the fold's artifacts (checkpoint, config, loader, etc.).

    Returns:
        A tuple containing (model, config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons)
        or None if any essential artifact is missing or loading fails.
    """
    try:
        logger.info(f"Loading artifacts from: {fold_dir}")

        # 1. Load Fold Configuration
        config_path = fold_dir / "config.yaml"
        if not config_path.is_file():
            logger.error(f"Fold config file not found in {fold_dir}")
            return None
        with open(config_path, 'r') as f:
            fold_config_dict = yaml.safe_load(f)
        fold_config = MainConfig(**fold_config_dict)  # Validate fold's config

        # 2. Load Saved Objects using torch.load
        test_loader_path = fold_dir / "test_loader.pt"
        scaler_path = fold_dir / "target_scaler.pt"
        input_size_path = fold_dir / "input_size.pt"
        prediction_index_path = fold_dir / "prediction_target_time_index.pt"

        if not all([p.is_file() for p in [test_loader_path, scaler_path, input_size_path]]):
            logger.error(f"Missing one or more required artifacts (test_loader, target_scaler, input_size) in {fold_dir}")
            return None

        try:
            # --- Explicitly set weights_only=False for non-model objects ---
            test_loader = torch.load(test_loader_path, weights_only=False)
            target_scaler = torch.load(scaler_path, weights_only=False)
            input_size = torch.load(input_size_path, weights_only=False)
            # --- End Modification ---
        except pickle.UnpicklingError as e:
            # Catch potential unpickling errors even with weights_only=False
            logger.error(f"Failed to unpickle saved object in {fold_dir}: {e}", exc_info=True)
            return None
        except AttributeError as e:
            # Catch potential issues if class definitions changed between saving and loading
            logger.error(f"AttributeError loading saved object in {fold_dir} (class definition changed?): {e}", exc_info=True)
            return None
        except Exception as e:
            # Catch other potential loading errors
            logger.error(f"Unexpected error loading saved objects (loader/scaler/size) from {fold_dir}: {e}", exc_info=True)
            return None

        # Retrieve forecast horizon list from the fold's config
        forecast_horizons = fold_config.features.forecast_horizon

        # --- Extract prediction target time index (if available) ---
        prediction_target_time_index: Optional[pd.Index] = None
        if prediction_index_path.is_file():
            try:
                prediction_target_time_index = torch.load(prediction_index_path, weights_only=False)
                # Basic validation
                if not isinstance(prediction_target_time_index, pd.Index):
                    logger.warning(f"Loaded prediction index from {prediction_index_path} is not a pandas Index.")
                    prediction_target_time_index = None
                else:
                    logger.debug(f"Loaded prediction target time index from {prediction_index_path}")
            except Exception as e:
                logger.warning(f"Failed to load prediction target time index from {prediction_index_path}: {e}")
        else:
            logger.warning(f"Prediction target time index file not found at {prediction_index_path}. Plotting x-axis might be inaccurate for ensemble plots.")
        # --- End Index Extraction ---


        # 3. Find Checkpoint and Load Model
        checkpoint_path = None
        try:
            # Use rglob to find the checkpoint potentially nested deeper
            checkpoints = list(fold_dir.glob("**/best_model_fold_*.ckpt"))
            if not checkpoints:
                logger.error(f"No 'best_model_fold_*.ckpt' checkpoint found in {fold_dir} or subdirectories.")
                return None
            if len(checkpoints) > 1:
                logger.warning(f"Multiple checkpoints found in {fold_dir}, using the first one: {checkpoints[0]}")
            checkpoint_path = checkpoints[0]

            logger.info(f"Loading model from checkpoint: {checkpoint_path}")
            model = LSTMForecastLightningModule.load_from_checkpoint(
                checkpoint_path,
                map_location=torch.device('cpu'),  # Optional: load to CPU first if memory is tight
                model_config=fold_config.model,
                train_config=fold_config.training,
                input_size=input_size,
                target_scaler=target_scaler
            )
            model.eval()
            logger.info(f"Successfully loaded model and artifacts from {fold_dir}")
            return model, fold_config, test_loader, target_scaler, input_size, prediction_target_time_index, forecast_horizons

        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {checkpoint_path}")
            return None
        except Exception as e:
            logger.error(f"Failed to load model from checkpoint {checkpoint_path} in {fold_dir}: {e}", exc_info=True)
            return None

    except Exception as e:
        logger.error(f"Generic error loading artifacts from {fold_dir}: {e}", exc_info=True)
        return None

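
# Illustrative helper (a sketch only, not used elsewhere in this module): maps a fold
# directory to the artifact paths load_fold_model_and_objects() expects. The file names
# mirror the ones loaded above; the checkpoint itself is resolved lazily via glob.
def _expected_fold_artifacts(fold_dir: Path) -> Dict[str, Path]:
    return {
        "config": fold_dir / "config.yaml",
        "test_loader": fold_dir / "test_loader.pt",
        "target_scaler": fold_dir / "target_scaler.pt",
        "input_size": fold_dir / "input_size.pt",
        "prediction_index": fold_dir / "prediction_target_time_index.pt",
        # Checkpoint is searched recursively: fold_dir / "**" / "best_model_fold_*.ckpt"
    }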


def make_ensemble_predictions(
    models: List[LSTMForecastLightningModule],
    test_loader: torch.utils.data.DataLoader,
    device: Optional[torch.device] = None
) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[np.ndarray]]:
    """
    Make predictions using an ensemble of models efficiently.

    Processes the test_loader once, getting predictions from all models per batch.

    Args:
        models: List of trained models (already in eval mode).
        test_loader: DataLoader for the test set.
        device: Device to run predictions on (e.g., torch.device("cuda:0")).
            If None, attempts to use GPU if available, else CPU.

    Returns:
        Tuple of (ensemble_predictions, targets):
        - ensemble_predictions: Dict containing ensemble predictions keyed by method
          ('mean', 'median', 'min', 'max'). Values are np.arrays.
          Returns None if prediction fails.
        - targets: Ground truth values as a single np.array. Returns None if prediction fails
          or targets are unavailable in loader.
    """
    if not models:
        logger.warning("make_ensemble_predictions received an empty list of models.")
        return None, None

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Running ensemble predictions on device: {device}")

    # Move all models to the target device
    for model in models:
        model.to(device)

    all_batch_preds: List[List[np.ndarray]] = [[] for _ in models]  # Outer list: models, Inner list: batches
    all_batch_targets: List[np.ndarray] = []
    targets_available = True

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            try:
                # Determine if batch contains targets
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    x, y = batch
                    x = x.to(device)
                    # Keep targets on CPU until needed for concatenation
                    all_batch_targets.append(y.cpu().numpy())
                else:
                    x = batch.to(device)
                    targets_available = False  # No targets found in this batch

                # Get predictions from all models for this batch
                for i, model in enumerate(models):
                    try:
                        pred = model(x)  # Shape: (batch, horizon)
                        all_batch_preds[i].append(pred.cpu().numpy())
                    except Exception as model_err:
                        logger.error(f"Error during prediction with model {i} on batch {batch_idx}: {model_err}", exc_info=True)
                        # Handle error: Fill with NaNs? Skip model? For now, fill with NaNs of expected shape
                        # Infer expected shape: (batch_size, horizon)
                        batch_size = x.shape[0]
                        horizon = models[0].output_size  # Assume all models have same horizon
                        nan_preds = np.full((batch_size, horizon), np.nan)
                        all_batch_preds[i].append(nan_preds)

            except Exception as batch_err:
                logger.error(f"Error processing batch {batch_idx} for ensemble prediction: {batch_err}", exc_info=True)
                # If a batch fails catastrophically, we might not be able to proceed reliably
                return None, None  # Indicate failure

    # Concatenate batch results for each model
    model_preds_concat = []
    for i in range(len(models)):
        if not all_batch_preds[i]:  # Check if any predictions were collected for this model
            logger.warning(f"No predictions collected for model index {i}. Skipping this model in ensemble.")
            continue  # Skip this model if it failed on all batches
        try:
            model_preds_concat.append(np.concatenate(all_batch_preds[i], axis=0))
        except ValueError as e:
            logger.error(f"Failed to concatenate predictions for model index {i}: {e}. Check for shape mismatches or empty lists.")
            # Decide how to handle: skip model or fail? Let's skip for robustness.
            continue

    if not model_preds_concat:
        logger.error("No valid predictions collected from any model in the ensemble.")
        return None, None

    # Concatenate targets if available
    targets_concat = None
    if targets_available and all_batch_targets:
        try:
            targets_concat = np.concatenate(all_batch_targets, axis=0)
        except ValueError as e:
            logger.error(f"Failed to concatenate targets: {e}")
            return None, None  # Fail if targets were expected but couldn't be combined
    elif targets_available and not all_batch_targets:
        logger.warning("Targets were expected based on first batch, but none were collected.")
        # Proceed without targets, returning None for them

    # Stack predictions from all models: Shape (num_models, num_samples, horizon)
    try:
        stacked_preds = np.stack(model_preds_concat, axis=0)
    except ValueError as e:
        logger.error(f"Failed to stack model predictions: {e}. Check if all models produced compatible shapes.")
        return None, targets_concat  # Return targets if available, but no ensemble preds

    # Calculate different ensemble predictions (handle NaNs potentially introduced by model failures)
    # np.nanmean, np.nanmedian etc. ignore NaNs
    ensemble_preds = {
        'mean': np.nanmean(stacked_preds, axis=0),
        'median': np.nanmedian(stacked_preds, axis=0),
        'min': np.nanmin(stacked_preds, axis=0),
        'max': np.nanmax(stacked_preds, axis=0)
    }

    logger.info(f"Ensemble predictions generated using {stacked_preds.shape[0]} models.")
    return ensemble_preds, targets_concat

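
# Small self-contained sketch of the aggregation step above: per-model predictions of
# shape (num_samples, num_horizons) are stacked to (num_models, num_samples, num_horizons)
# and reduced along the model axis with NaN-aware statistics. The numbers are made up.
def _ensemble_aggregation_demo() -> Dict[str, np.ndarray]:
    rng = np.random.default_rng(0)
    per_model_preds = [rng.normal(size=(4, 2)) for _ in range(3)]  # 3 models, 4 samples, 2 horizons
    per_model_preds[1][0, 0] = np.nan  # a failed prediction is ignored by the nan* reductions
    stacked = np.stack(per_model_preds, axis=0)  # shape (3, 4, 2)
    return {
        'mean': np.nanmean(stacked, axis=0),
        'median': np.nanmedian(stacked, axis=0),
        'min': np.nanmin(stacked, axis=0),
        'max': np.nanmax(stacked, axis=0),
    }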


def evaluate_ensemble_for_test_fold(
    test_fold_num: int,
    all_fold_dirs: List[Path],
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None # Removed, get from loaded objects
) -> Optional[Dict[str, Dict[str, float]]]:
    """
    Evaluates ensemble predictions for a specific test fold.

    Args:
        test_fold_num: The 1-based number of the fold to use as the test set.
        all_fold_dirs: List of paths to all fold directories.
        output_base_dir: Base directory for saving evaluation results/plots.

    Returns:
        Dictionary containing metrics for each ensemble method for this test fold,
        or None if evaluation fails.
    """
    logger.info(f"--- Evaluating Ensemble: Test Fold {test_fold_num} ---")
    test_fold_dir = output_base_dir / f"fold_{test_fold_num:02d}"

    load_result = load_fold_model_and_objects(test_fold_dir)
    if load_result is None:
        logger.error(f"Failed to load necessary artifacts for test fold {test_fold_num}. Skipping ensemble evaluation for this fold.")
        return None
    # Unpack results including the prediction time index and horizons
    _, test_fold_config, test_loader, target_scaler, _, prediction_target_time_index, test_forecast_horizons = load_result

    # Load models from all *other* folds
    ensemble_models: List[LSTMForecastLightningModule] = []
    model_forecast_horizons = None  # Track horizons from loaded models
    for i, fold_dir in enumerate(all_fold_dirs):
        current_fold_num = i + 1
        if current_fold_num == test_fold_num:
            continue  # Skip the test fold itself

        model_load_result = load_fold_model_and_objects(fold_dir)
        if model_load_result:
            model, _, _, _, _, _, fold_horizons = model_load_result  # Only need the model here
            if model:
                ensemble_models.append(model)
                # Store horizons from the first successful model load
                if model_forecast_horizons is None:
                    model_forecast_horizons = fold_horizons
                # Optional: Check consistency of horizons across ensemble models
                elif set(model_forecast_horizons) != set(fold_horizons):
                    logger.error(f"Inconsistent forecast horizons between ensemble models! Test fold {test_fold_num} expected {test_forecast_horizons}, "
                                 f"Model {i+1} has {fold_horizons}. Ensemble may be invalid.")
                    # Decide how to handle: error out, or proceed with caution?
                    # return None # Option: Fail hard
        else:
            logger.warning(f"Could not load model from fold {current_fold_num} to include in ensemble for test fold {test_fold_num}.")

    if len(ensemble_models) < 2:
        logger.warning(f"Skipping ensemble evaluation for test fold {test_fold_num}: "
                       f"Need at least 2 models for ensemble, only loaded {len(ensemble_models)}.")
        return {}  # Return empty dict, not None, to indicate process ran but no ensemble formed

    # Check consistency between test fold horizons and ensemble model horizons
    if model_forecast_horizons is None:  # Should not happen if len(ensemble_models) >= 1
        logger.error(f"Could not determine forecast horizons from ensemble models for test fold {test_fold_num}.")
        return None
    if set(test_forecast_horizons) != set(model_forecast_horizons):
        logger.error(f"Forecast horizons of test fold {test_fold_num} ({test_forecast_horizons}) do not match "
                     f"horizons from ensemble models ({model_forecast_horizons}). Cannot evaluate.")
        return None

    # Make ensemble predictions using the loaded models and the test fold's data loader
    # Use the test fold's config to determine device implicitly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble_preds_dict, targets_np = make_ensemble_predictions(ensemble_models, test_loader, device=device)

    if ensemble_preds_dict is None or targets_np is None:
        logger.error(f"Failed to generate ensemble predictions or retrieve targets for test fold {test_fold_num}.")
        return None  # Indicate failure

    # Evaluate each ensemble method's predictions against the test fold's targets
    fold_ensemble_results: Dict[str, Dict[str, float]] = {}
    for method, preds_np in ensemble_preds_dict.items():
        logger.info(f"Evaluating ensemble method '{method}' for test fold {test_fold_num}...")

        # Define a unique output directory for this method's plots
        method_plot_dir = output_base_dir / "ensemble_eval_plots" / f"test_fold_{test_fold_num:02d}" / f"method_{method}"

        # Use the prediction_target_time_index loaded earlier
        prediction_time_index_for_plot = None
        if prediction_target_time_index is not None:
            if len(prediction_target_time_index) == targets_np.shape[0]:
                prediction_time_index_for_plot = prediction_target_time_index
            else:
                logger.warning(f"Length of loaded prediction target time index ({len(prediction_target_time_index)}) does not match "
                               f"number of samples ({targets_np.shape[0]}) for test fold {test_fold_num}, method '{method}'. Plot x-axis may be incorrect.")

        # Call the standard evaluation function
        metrics = evaluate_fold_predictions(
            y_true_scaled=targets_np,
            y_pred_scaled=preds_np,
            target_scaler=target_scaler,
            eval_config=test_fold_config.evaluation,
            fold_num=test_fold_num - 1,
            output_dir=str(method_plot_dir.parent.parent),
            plot_subdir=f"method_{method}",
            prediction_time_index=prediction_time_index_for_plot,  # Pass the index
            forecast_horizons=test_forecast_horizons,
            plot_title_prefix=f"Ensemble ({method})"
        )
        fold_ensemble_results[method] = metrics

    logger.info(f"--- Finished Ensemble Evaluation: Test Fold {test_fold_num} ---")
    return fold_ensemble_results


def run_ensemble_evaluation(
    config: MainConfig,  # Pass main config for context if needed, though fold configs are loaded
    output_base_dir: Path,
    # full_data_index: Optional[pd.Index] = None # Removed, get index from loaded objects
) -> Dict[int, Dict[str, Dict[str, float]]]:
    """
    Run ensemble evaluation across all folds, treating each as the test set once.

    Args:
        config: The main configuration object (potentially unused if fold configs sufficient).
        output_base_dir: Base directory where fold outputs are stored.

    Returns:
        Dictionary containing ensemble metrics for each test fold:
        { test_fold_num: { ensemble_method: { metric_name: value, ... }, ... }, ... }
    """
    logger.info("===== Starting Cross-Validated Ensemble Evaluation =====")
    all_ensemble_results: Dict[int, Dict[str, Dict[str, float]]] = {}

    # Discover fold directories
    fold_dirs = sorted([d for d in output_base_dir.glob("fold_*") if d.is_dir()])
    if not fold_dirs:
        logger.error(f"No fold directories found in {output_base_dir} for ensemble evaluation.")
        return {}
    if len(fold_dirs) < 2:
        logger.warning(f"Need at least 2 folds for ensemble evaluation, found {len(fold_dirs)}. Skipping.")
        return {}

    logger.info(f"Found {len(fold_dirs)} fold directories.")

    # Iterate through each fold, designating it as the test fold
    for i, test_fold_dir in enumerate(fold_dirs):
        test_fold_num = i + 1  # 1-based fold number
        try:
            results_for_test_fold = evaluate_ensemble_for_test_fold(
                test_fold_num=test_fold_num,
                all_fold_dirs=fold_dirs,
                output_base_dir=output_base_dir,
                # full_data_index=full_data_index # Removed
            )

            if results_for_test_fold is not None:
                # Only add results if the evaluation didn't fail completely
                all_ensemble_results[test_fold_num] = results_for_test_fold

        except Exception as e:
            # Catch unexpected errors during a specific test fold evaluation
            logger.error(f"Unexpected error during ensemble evaluation with test fold {test_fold_num}: {e}", exc_info=True)
            continue  # Continue to the next fold

    # Saving is handled by the main script (`forecasting_model_run.py`) which calls this
    if not all_ensemble_results:
        logger.warning("Ensemble evaluation finished, but no results were generated.")
    else:
        logger.info("===== Finished Cross-Validated Ensemble Evaluation =====")

    return all_ensemble_results
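
# Minimal usage sketch (illustrative only): run the cross-validated ensemble evaluation
# over an existing output directory that contains fold_01, fold_02, ... subdirectories.
# The config path and output directory below are assumptions, not part of this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with open("config.yaml", "r") as f:  # hypothetical config file
        main_cfg = MainConfig(**yaml.safe_load(f))

    results = run_ensemble_evaluation(config=main_cfg, output_base_dir=Path("output"))
    for fold_num, method_metrics in results.items():
        print(f"Test fold {fold_num}: {method_metrics}")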
0
forecasting_model/train/folds.py
Normal file
287
forecasting_model/train/model.py
Normal file
@@ -0,0 +1,287 @@
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
from typing import Optional, Dict, Any, Union, List, Tuple
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Assuming config_model is in sibling directory utils/
from forecasting_model.utils.forecast_config_model import ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)


class LSTMForecastLightningModule(pl.LightningModule):
    """
    PyTorch Lightning Module for LSTM-based time series forecasting.

    Encapsulates the model architecture, training, validation, and test logic.
    Uses torchmetrics for efficient metric calculation.
    """
    def __init__(
        self,
        model_config: ModelConfig,
        train_config: TrainingConfig,
        input_size: int,
        target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None,
    ):
        super().__init__()

        # --- Validate & Store Configs ---
        if input_size <= 0:
            raise ValueError("`input_size` must be provided as a positive integer during model instantiation.")
        self._input_size = input_size  # Use a temporary attribute

        # Ensure forecast_horizon is a valid list in the config
        if not hasattr(model_config, 'forecast_horizon') or \
           not isinstance(model_config.forecast_horizon, list) or \
           not model_config.forecast_horizon or \
           any(h <= 0 for h in model_config.forecast_horizon):
            raise ValueError("ModelConfig requires `forecast_horizon` to be a non-empty list of positive integers.")

        # Output size is the number of horizons we predict
        self.output_size = len(model_config.forecast_horizon)
        # Store the actual horizon list for reference if needed, ensure sorted
        self.forecast_horizons = sorted(model_config.forecast_horizon)

        self.model_config = model_config
        self.train_config = train_config
        self.target_scaler = target_scaler  # Store scaler for this fold

        # Use save_hyperparameters() - forecast_horizon is part of model_config which is saved
        self.save_hyperparameters('model_config', 'train_config', 'input_size', ignore=['target_scaler'])
        # Note: Pydantic models might not be perfectly saved/loaded by PL's hparams, check if needed.
        # If issues arise loading, might need to flatten relevant hparams manually.

        # --- Define Model Layers ---
        self.lstm = nn.LSTM(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.model_config.hidden_size,
            num_layers=self.hparams.model_config.num_layers,
            batch_first=True,
            dropout=self.hparams.model_config.dropout if self.hparams.model_config.num_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(self.hparams.model_config.dropout)

        # Output layer maps LSTM hidden state to the number of forecast horizons
        self.fc = nn.Linear(self.hparams.model_config.hidden_size, self.output_size)

        # Optional residual connection handling
        self.use_residual_skips = self.hparams.model_config.use_residual_skips
        self.residual_projection = None
        if self.use_residual_skips:
            # If input size doesn't match hidden size, project input
            if self.hparams.input_size != self.hparams.model_config.hidden_size:
                # Use hparams.input_size here
                self.residual_projection = nn.Linear(self.hparams.input_size, self.hparams.model_config.hidden_size)
            logger.info("Residual connections enabled.")
            if self.residual_projection:
                logger.info("Residual projection layer added.")

        # --- Define Loss Function ---
        if self.hparams.train_config.loss_function.upper() == 'MSE':
            self.criterion = nn.MSELoss()
        elif self.hparams.train_config.loss_function.upper() == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Unsupported loss function: {self.hparams.train_config.loss_function}")

        # --- Define Metrics (TorchMetrics) ---
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanAbsoluteError(),
            torchmetrics.MeanSquaredError(squared=False)  # RMSE
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

        self.val_MeanAbsoluteError_Original_Scale = torchmetrics.MeanAbsoluteError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the LSTM network.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, input_size)

        Returns:
            Predictions tensor of shape (batch_size, len(forecast_horizons))
            where each element corresponds to a predicted horizon in sorted order.
        """
        # LSTM forward pass
        lstm_out, (hidden, cell) = self.lstm(x)  # Shape: (batch, seq_len, hidden_size)

        # Output from the last time step
        last_time_step_out = lstm_out[:, -1, :]  # Shape: (batch_size, hidden_size)

        # Apply dropout
        last_time_step_out = self.dropout(last_time_step_out)

        # Optional Residual Connection
        if self.use_residual_skips:
            residual = x[:, -1, :]  # Input from the last time step: (batch_size, input_size)
            if self.residual_projection:
                residual = self.residual_projection(residual)  # Project to hidden_size
            last_time_step_out = last_time_step_out + residual

        # Final fully connected layer
        predictions = self.fc(last_time_step_out)  # Shape: (batch_size, output_size/len(horizons))

        return predictions  # Shape: (batch_size, len(forecast_horizons))

    def _calculate_loss(self, outputs, targets):
        # Shapes should now be (batch_size, len(horizons)) for both
        if outputs.shape != targets.shape:
            # Minimal check, dataset __getitem__ should ensure this
            raise ValueError(f"Output shape {outputs.shape} doesn't match target shape {targets.shape} for loss calculation.")
        return self.criterion(outputs, targets)

    def _inverse_transform(self, data: torch.Tensor) -> Optional[torch.Tensor]:
        """Helper to inverse transform data (preds or targets) using the stored target scaler."""
        if self.target_scaler is None:
            return None

        data_cpu = data.detach().cpu().numpy().astype(np.float64)
        original_shape = data_cpu.shape  # e.g., (batch_size, len(horizons))
        num_elements = data_cpu.size

        # Scaler expects 2D input (N, 1)
        data_flat = data_cpu.reshape(num_elements, 1)

        try:
            inversed_np = self.target_scaler.inverse_transform(data_flat)
            # Return as tensor on the original device, potentially reshaped
            inversed_tensor = torch.from_numpy(inversed_np).float().to(data.device)
            # Reshape back to original multi-horizon shape
            return inversed_tensor.reshape(original_shape)
            # return inversed_tensor.flatten() # Keep flat if needed for specific metric inputs
        except Exception as e:
            logger.error(f"Failed to inverse transform data: {e}", exc_info=True)
            return None

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch  # Shapes: x=(batch, seq_len, features), y=(batch, len(horizons))
        outputs = self(x)  # Scaled outputs: (batch, len(horizons))
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        self.train_metrics.update(outputs, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        outputs = self(x)  # Scaled outputs
        loss = self._calculate_loss(outputs, y)

        # Log scaled metrics
        self.val_metrics.update(outputs, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True)

        # Log MAE on ORIGINAL scale (primary metric for checkpoints)
        if self.target_scaler is not None:
            # Inverse transform keeps the (batch, len(horizons)) shape
            outputs_inv = self._inverse_transform(outputs)
            y_inv = self._inverse_transform(y)

            if outputs_inv is not None and y_inv is not None:
                # Ensure shapes match
                if outputs_inv.shape == y_inv.shape:
                    # It will compute the average MAE across all elements if multi-dim
                    self.val_MeanAbsoluteError_Original_Scale.update(outputs_inv, y_inv)
                    self.log('val_MeanAbsoluteError_Original_Scale', self.val_MeanAbsoluteError_Original_Scale, on_step=False, on_epoch=True, prog_bar=True, logger=True)
                else:
                    logger.warning(f"Shape mismatch after inverse transform in validation: Preds {outputs_inv.shape}, Targets {y_inv.shape}")
            else:
                logger.warning("Could not compute original scale MAE in validation due to inverse transform failure.")

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        # Optional: Keep this method ONLY if you want trainer.test() to log metrics.
        # For getting predictions for evaluation, use predict_step.
        # If evaluate_fold_predictions handles all metrics, this can be simplified or removed.
        # Let's simplify it for now to only log loss if needed.
        try:
            x, y = batch
            outputs = self(x)
            loss = self._calculate_loss(outputs, y)
            # Log scaled test metrics if you still want trainer.test() to report them
            self.test_metrics.update(outputs, y)
            self.log('test_loss_step', loss, on_step=True, on_epoch=False)  # Log step loss if needed
            self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True)
            # No return needed if just logging
        except Exception as e:
            logger.error(f"Error occurred in test_step for batch {batch_idx}: {e}", exc_info=True)
            # Optionally log something to indicate failure

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Dict[str, torch.Tensor]:
        """
        Runs inference for prediction and returns scaled predictions and targets.
        'batch' might contain only features depending on the DataLoader setup for predict.
        Let's assume the test_loader yields (x, y) pairs for convenience here.
        """
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            x, y = batch
        else:
            # Assume batch contains only features if not a pair
            x = batch
            y = None  # No targets available during prediction if dataloader only yields features

        outputs = self(x)  # Scaled outputs

        result = {'preds_scaled': outputs.detach().cpu()}
        if y is not None:
            # Include targets if they were part of the batch (e.g., using test_loader for predict)
            result['targets_scaled'] = y.detach().cpu()

        return result

    def configure_optimizers(self) -> Union[optim.Optimizer, Tuple[List[optim.Optimizer], List[Dict[str, Any]]]]:
        """
        Configure the optimizer (Adam) and optional LR scheduler.
        """
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.train_config.learning_rate  # Access lr via hparams
        )
        logger.info(f"Configured Adam optimizer with LR: {self.hparams.train_config.learning_rate}")

        # Optional LR Scheduler configuration
        scheduler_config = None
        if hasattr(self.hparams.train_config, 'scheduler_step_size') and \
           self.hparams.train_config.scheduler_step_size is not None and \
           hasattr(self.hparams.train_config, 'scheduler_gamma') and \
           self.hparams.train_config.scheduler_gamma is not None:

            if self.hparams.train_config.scheduler_step_size > 0 and 0 < self.hparams.train_config.scheduler_gamma < 1:
                logger.info(f"Configuring StepLR scheduler with step_size={self.hparams.train_config.scheduler_step_size} "
                            f"and gamma={self.hparams.train_config.scheduler_gamma}")
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.hparams.train_config.scheduler_step_size,
                    gamma=self.hparams.train_config.scheduler_gamma
                )
                scheduler_config = {
                    'scheduler': scheduler,
                    'interval': 'epoch',  # or 'step'
                    'frequency': 1,
                    'monitor': 'val_loss',  # Optional: Only step if monitor improves (for ReduceLROnPlateau)
                }
            else:
                logger.warning("Scheduler parameters provided but invalid (step_size must be >0, 0<gamma<1). No scheduler configured.")

        # Example for ReduceLROnPlateau (if needed later)
        # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
        # scheduler_config = {'scheduler': scheduler, 'monitor': 'val_loss'}

        if scheduler_config:
            return [optimizer], [scheduler_config]
        else:
            return optimizer
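
# Minimal smoke-test sketch (illustrative only): builds lightweight stand-in config
# objects with just the attributes this module reads, then runs a forward pass on
# random data. The real ModelConfig/TrainingConfig are Pydantic models and may require
# additional fields; SimpleNamespace is used here purely for the sketch.
if __name__ == "__main__":
    from types import SimpleNamespace

    model_cfg = SimpleNamespace(
        hidden_size=32, num_layers=2, dropout=0.1,
        forecast_horizon=[1, 6, 12], use_residual_skips=False,
    )
    train_cfg = SimpleNamespace(loss_function="MSE", learning_rate=1e-3)

    module = LSTMForecastLightningModule(
        model_config=model_cfg, train_config=train_cfg, input_size=8, target_scaler=None,
    )
    x = torch.randn(4, 24, 8)  # (batch, seq_len, input_size)
    preds = module(x)
    print(preds.shape)  # torch.Size([4, 3]) -- one output per forecast horizon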