"""
|
|
Classic training routine: Train on initial data segment, validate and test on final segments.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import torch
|
|
import yaml
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
|
|
from pytorch_lightning.loggers import CSVLogger
|
|
from typing import Dict, Optional
|
|
|
|
from forecasting_model.utils.forecast_config_model import MainConfig
|
|
from forecasting_model.data_processing import prepare_fold_data_and_loaders, split_data_classic
|
|
from forecasting_model.train.model import LSTMForecastLightningModule
|
|
from forecasting_model.evaluation import evaluate_fold_predictions
|
|
|
|
from forecasting_model.utils.helper import save_results
|
|
from forecasting_model.io.plotting import plot_loss_curve_from_csv
|
|
|
|
logger = logging.getLogger(__name__)
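
# Example usage (sketch; file names are illustrative and the YAML is assumed to map
# directly onto MainConfig):
#
#   config = MainConfig(**yaml.safe_load(open("config.yaml")))
#   full_df = pd.read_csv("data.csv", index_col=0, parse_dates=True)
#   metrics = run_classic_training(config, full_df, output_base_dir=Path("output"))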


def run_classic_training(
    config: MainConfig,
    full_df: pd.DataFrame,
    output_base_dir: Path
) -> Optional[Dict[str, float]]:
    """
    Runs a single training pipeline using a classic train/val/test split.

    Args:
        config: The main configuration object.
        full_df: The complete raw DataFrame.
        output_base_dir: The base directory where general outputs are saved.
            Classic results will be saved in a subdirectory.

    Returns:
        A dictionary containing test metrics (e.g., {'MAE': ..., 'RMSE': ...})
        for the classic run, or None if it fails.
    """
    run_start_time = time.perf_counter()
    logger.info("--- Starting Classic Training Run ---")

    # Define a specific output directory for this run
    classic_output_dir = output_base_dir / "classic_run"
    classic_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Classic run outputs will be saved to: {classic_output_dir}")

    test_metrics: Optional[Dict[str, float]] = None
    best_val_score: Optional[float] = None
    best_model_path: Optional[str] = None

    try:
        # --- Data Splitting ---
        logger.info("Splitting data into classic train/val/test sets...")
        n_samples = len(full_df)
        val_frac = config.cross_validation.val_size_fraction
        test_frac = config.cross_validation.test_size_fraction
        train_idx, val_idx, test_idx = split_data_classic(n_samples, val_frac, test_frac)
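        # The returned indices are assumed to be chronologically ordered, contiguous
        # positional arrays (train block first, then val, then test), so the test
        # block is the final segment of the series.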

        # Store test datetime index for evaluation plotting
        test_datetime_index = full_df.iloc[test_idx].index

        # --- Data Preparation ---
        logger.info("Preparing data loaders for the classic split...")
        train_loader, val_loader, test_loader, target_scaler, input_size = prepare_fold_data_and_loaders(
            full_df=full_df,
            train_idx=train_idx,
            val_idx=val_idx,
            test_idx=test_idx,
            target_col=config.data.target_col,
            feature_config=config.features,
            train_config=config.training,
            eval_config=config.evaluation
        )
        logger.info(f"Data loaders prepared. Input size determined: {input_size}")

        # Save artifacts specific to this run if needed (e.g., for later inference)
        torch.save(test_loader, classic_output_dir / "classic_test_loader.pt")
        torch.save(target_scaler, classic_output_dir / "classic_target_scaler.pt")
        torch.save(input_size, classic_output_dir / "classic_input_size.pt")
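        # These artifacts can be reloaded later for standalone inference, e.g. (sketch):
        #   test_loader = torch.load(classic_output_dir / "classic_test_loader.pt")
        #   target_scaler = torch.load(classic_output_dir / "classic_target_scaler.pt")
        #   input_size = torch.load(classic_output_dir / "classic_input_size.pt")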

        # Save config for this run
        try:
            config_dump = config.model_dump()  # Pydantic v2
        except AttributeError:
            config_dump = config.dict()  # Fallback for Pydantic v1 (model_dump is v2-only)
        with open(classic_output_dir / "config.yaml", 'w') as f:
            yaml.dump(config_dump, f, default_flow_style=False)

        # --- Model Initialization ---
        model = LSTMForecastLightningModule(
            model_config=config.model,
            train_config=config.training,
            input_size=input_size,
            target_scaler=target_scaler
        )
        logger.info("Classic LSTMForecastLightningModule initialized.")

        # --- PyTorch Lightning Callbacks ---
        monitor_metric = "val_MeanAbsoluteError"  # Monitor same metric as CV folds
        monitor_mode = "min"
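        # Note: the LightningModule is assumed to log 'val_MeanAbsoluteError' during
        # validation; EarlyStopping and ModelCheckpoint below monitor this key and
        # cannot operate correctly if it is missing.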

        early_stop_callback = None
        if config.training.early_stopping_patience is not None and config.training.early_stopping_patience > 0:
            early_stop_callback = EarlyStopping(
                monitor=monitor_metric, min_delta=0.0001,
                patience=config.training.early_stopping_patience, verbose=True, mode=monitor_mode
            )
            logger.info(f"Enabled EarlyStopping: monitor='{monitor_metric}', patience={config.training.early_stopping_patience}")

        checkpoint_callback = ModelCheckpoint(
            dirpath=classic_output_dir / "checkpoints",
            filename="best_classic_model",  # Simple filename
            save_top_k=1, monitor=monitor_metric, mode=monitor_mode, verbose=True
        )
        logger.info(f"Enabled ModelCheckpoint: monitor='{monitor_metric}', mode='{monitor_mode}'")

        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        callbacks = [checkpoint_callback, lr_monitor]
        if early_stop_callback:
            callbacks.append(early_stop_callback)

        # --- PyTorch Lightning Logger ---
        pl_logger = CSVLogger(save_dir=str(classic_output_dir), name="training_logs")
        logger.info(f"Using CSVLogger, logs will be saved in: {pl_logger.log_dir}")

        # --- PyTorch Lightning Trainer ---
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        devices = 1 if accelerator == 'gpu' else None
        precision = getattr(config.training, 'precision', 32)

        trainer = pl.Trainer(
            accelerator=accelerator, devices=devices,
            max_epochs=config.training.epochs,
            callbacks=callbacks, logger=pl_logger,
            log_every_n_steps=max(1, len(train_loader) // 10),
            enable_progress_bar=True,
            gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
            precision=precision,
        )
        logger.info(f"Initialized PyTorch Lightning Trainer: accelerator='{accelerator}', devices={devices}, precision={precision}")

        # --- Training ---
        logger.info("Starting classic model training...")
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        logger.info("Classic model training finished.")

        # Store best validation score and path
        best_val_score_tensor = trainer.checkpoint_callback.best_model_score
        best_model_path = trainer.checkpoint_callback.best_model_path
        best_val_score = best_val_score_tensor.item() if best_val_score_tensor is not None else None

        if best_val_score is not None:
            logger.info(f"Best validation score ({monitor_metric}): {best_val_score:.4f}")
            logger.info(f"Best model checkpoint path: {best_model_path}")
        else:
            logger.warning(f"Could not retrieve best validation score/path (metric: {monitor_metric}). Evaluation might use last model.")
            best_model_path = None

        # --- Prediction on Test Set ---
        logger.info("Starting prediction on classic test set using best checkpoint...")
        prediction_results_list = trainer.predict(
            ckpt_path=best_model_path if best_model_path else 'last',
            dataloaders=test_loader
        )

        # --- Evaluation ---
        if not prediction_results_list:
            logger.error("Predict phase did not return any results for classic run.")
            test_metrics = None
        else:
            try:
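                # Each predict batch is expected to be a dict of tensors from the
                # LightningModule's predict_step: 'preds_scaled' (always) and
                # 'targets_scaled' (when targets are available), each of shape
                # (batch_size, n_horizons).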
                # Shapes: (n_samples, len(horizons))
                all_preds_scaled = torch.cat([b['preds_scaled'] for b in prediction_results_list], dim=0).numpy()
                n_predictions = len(all_preds_scaled)  # Number of samples actually predicted

                if 'targets_scaled' in prediction_results_list[0]:
                    all_targets_scaled = torch.cat([b['targets_scaled'] for b in prediction_results_list], dim=0).numpy()
                    if len(all_targets_scaled) != n_predictions:
                        logger.error(f"Classic Run: Mismatch between number of predictions ({n_predictions}) and targets ({len(all_targets_scaled)}).")
                        raise ValueError("Prediction and target count mismatch during classic evaluation.")
                else:
                    raise ValueError("Targets missing from prediction results.")

                logger.info(f"Processing {n_predictions} prediction results for classic test set...")

                # --- Calculate Correct Time Index for Plotting (First Horizon) ---
                target_time_index_for_plotting = None
                if test_idx is not None and config.features.forecast_horizon:
                    try:
                        test_block_index = full_df.index[test_idx]  # Use the test_idx from classic split
                        seq_len = config.features.sequence_length
                        first_horizon = config.features.forecast_horizon[0]
                        start_offset = seq_len + first_horizon - 1
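                        # The first usable prediction corresponds to position
                        # seq_len + first_horizon - 1 within the test block: the model
                        # consumes seq_len past observations as input before it can emit
                        # a forecast first_horizon steps ahead.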
                        if start_offset < len(test_block_index):
                            end_index = min(start_offset + n_predictions, len(test_block_index))
                            target_time_index_for_plotting = test_block_index[start_offset:end_index]
                            if len(target_time_index_for_plotting) != n_predictions:
                                logger.warning(f"Classic Run: Calculated target time index length ({len(target_time_index_for_plotting)}) "
                                               f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
                                target_time_index_for_plotting = None
                        else:
                            logger.warning(f"Classic Run: Cannot calculate target time index, start offset ({start_offset}) "
                                           f"exceeds test block length ({len(test_block_index)}).")
                    except Exception as e:
                        logger.error(f"Classic Run: Error calculating target time index for plotting: {e}", exc_info=True)
                        target_time_index_for_plotting = None  # Ensure it's None if error occurs
                else:
                    logger.warning("Classic Run: Skipping target time index calculation (missing test_idx or forecast_horizon).")
                # --- End Index Calculation ---

                # Use the classic run specific objects and config
                test_metrics = evaluate_fold_predictions(
                    y_true_scaled=all_targets_scaled,
                    y_pred_scaled=all_preds_scaled,
                    target_scaler=target_scaler,
                    eval_config=config.evaluation,
                    fold_num=-1,  # Indicate classic run
                    output_dir=str(classic_output_dir),
                    plot_subdir="plots",
                    prediction_time_index=target_time_index_for_plotting,  # Pass the correctly calculated index
                    forecast_horizons=config.features.forecast_horizon,
                    plot_title_prefix="Classic Run"
                )
                # Save metrics
                save_results({"overall_metrics": test_metrics}, classic_output_dir / "test_metrics.json")
                logger.info(f"Classic run test metrics (overall): {test_metrics}")

                # --- Plot Loss Curve for Classic Run ---
                try:
                    # Find metrics.csv, which CSVLogger typically writes inside a version_*/ subdirectory
                    classic_log_dir = classic_output_dir / "training_logs"
                    metrics_file = None
                    version_dirs = list(classic_log_dir.glob("version_*"))
                    if version_dirs:
                        # Assume the latest version directory contains the relevant logs
                        latest_version_dir = max(version_dirs, key=lambda p: p.stat().st_mtime)
                        potential_metrics_file = latest_version_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file
                        else:
                            logger.warning(f"Classic Run: metrics.csv not found in latest version directory: {latest_version_dir}")
                    else:
                        # Fallback if no version_* directories exist (less common with CSVLogger)
                        potential_metrics_file = classic_log_dir / "metrics.csv"
                        if potential_metrics_file.is_file():
                            metrics_file = potential_metrics_file

                    if metrics_file and metrics_file.is_file():
                        plot_loss_curve_from_csv(
                            metrics_csv_path=metrics_file,
                            output_path=classic_output_dir / "loss_curve.png",
                            title="Classic Run Training Progression",
                            train_loss_col='train_loss',  # Column names as logged to metrics.csv
                            val_loss_col='val_loss'
                        )
                        logger.info(f"Generated loss curve for classic run from: {metrics_file}")
                    else:
                        logger.warning(f"Classic Run: Could not find metrics.csv in {classic_log_dir} or its version subdirectories for loss curve plot.")
                except Exception as plot_e:
                    logger.error(f"Classic Run: Failed to generate loss curve plot: {plot_e}", exc_info=True)
                # --- End Classic Loss Plotting ---

            except Exception as e:
                logger.error(f"Error processing classic prediction results: {e}", exc_info=True)
                test_metrics = None

    except Exception as e:
        logger.error(f"An error occurred during the classic training pipeline: {e}", exc_info=True)
        test_metrics = None  # Indicate failure

    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        run_end_time = time.perf_counter()
        logger.info(f"--- Finished Classic Training Run in {run_end_time - run_start_time:.2f} seconds ---")

    return test_metrics