intermediate backup
This commit is contained in:
@ -13,21 +13,24 @@ from pytorch_lightning.loggers import CSVLogger
|
||||
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
||||
|
||||
# Import necessary components from your project structure
|
||||
# Assumes forecasting_model is an installable package or is available on PYTHONPATH
|
||||
from forecasting_model.utils.forecast_config_model import MainConfig
|
||||
from forecasting_model.data_processing import (
|
||||
load_raw_data,
|
||||
TimeSeriesCrossValidationSplitter,
|
||||
from forecasting_model.utils.data_processing import (
|
||||
prepare_fold_data_and_loaders
|
||||
)
|
||||
from forecasting_model.utils.dataset_splitter import TimeSeriesCrossValidationSplitter
|
||||
from forecasting_model.io.data import load_raw_data
|
||||
from forecasting_model.train.model import LSTMForecastLightningModule
|
||||
from forecasting_model.evaluation import evaluate_fold_predictions
|
||||
from forecasting_model.utils.evaluation import evaluate_fold_predictions
|
||||
from forecasting_model.train.ensemble_evaluation import run_ensemble_evaluation
|
||||
|
||||
# Import the new classic training function
|
||||
from forecasting_model.train.classic import run_classic_training
|
||||
from forecasting_model.train.classic import run_model_training
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from forecasting_model.utils.helper import parse_arguments, load_config, set_seeds, aggregate_cv_metrics, save_results
|
||||
from forecasting_model.utils.helper import (
|
||||
parse_arguments, load_config,
|
||||
set_seeds, aggregate_cv_metrics,
|
||||
save_results, calculate_h1_target_index
|
||||
)
|
||||
from forecasting_model.io.plotting import plot_loss_curve_from_csv, create_multi_horizon_time_series_plot, save_plot
|
||||
|
||||
# Silence overly verbose libraries if needed
|
||||
@ -46,6 +49,7 @@ logger = logging.getLogger()
|
||||
|
||||
|
||||
# --- Single Fold Processing Function ---
|
||||
# noinspection PyInconsistentReturns
|
||||
def run_single_fold(
|
||||
fold_num: int,
|
||||
train_idx: np.ndarray,
|
||||
@ -53,8 +57,9 @@ def run_single_fold(
|
||||
test_idx: np.ndarray,
|
||||
config: MainConfig,
|
||||
full_df: pd.DataFrame,
|
||||
output_base_dir: Path # Receives Path object from run_training_pipeline
|
||||
) -> Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]:
|
||||
output_base_dir: Path,
|
||||
enable_progress_bar: bool = True
|
||||
) -> Optional[Tuple[Dict[str, float], Optional[float], Optional[Path], Optional[Path], Optional[Path], Optional[Path], Optional[Path]]]:
|
||||
"""
|
||||
Runs the pipeline for a single cross-validation fold.
|
||||
|
||||
@ -66,6 +71,7 @@ def run_single_fold(
|
||||
config: The main configuration object.
|
||||
full_df: The complete raw DataFrame.
|
||||
output_base_dir: The base directory Path for saving results.
|
||||
enable_progress_bar: Whether to enable progress bar.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@ -73,6 +79,7 @@ def run_single_fold(
|
||||
- best_val_score: The best validation score achieved during training (or None).
|
||||
- saved_model_path: Path to the best saved model checkpoint (or None).
|
||||
- saved_target_scaler_path: Path to the saved target scaler (or None).
|
||||
- saved_data_scaler_path: Path to the saved data feature scaler (or None).
|
||||
- saved_input_size_path: Path to the saved input size file (or None).
|
||||
- saved_config_path: Path to the saved config file for this fold (or None).
|
||||
"""
|
||||
@ -92,19 +99,23 @@ def run_single_fold(
|
||||
all_preds_scaled: Optional[np.ndarray] = None
|
||||
all_targets_scaled: Optional[np.ndarray] = None
|
||||
target_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Need to keep scaler reference
|
||||
data_scaler: Optional[Union[StandardScaler, MinMaxScaler]] = None # Added to keep data scaler reference
|
||||
prediction_target_time_index_h1: Optional[pd.DatetimeIndex] = None
|
||||
pl_logger = None
|
||||
|
||||
# Variables to store paths of saved artifacts
|
||||
saved_model_path: Optional[Path] = None
|
||||
saved_target_scaler_path: Optional[Path] = None
|
||||
saved_data_scaler_path: Optional[Path] = None # Added
|
||||
saved_input_size_path: Optional[Path] = None
|
||||
saved_config_path: Optional[Path] = None
|
||||
|
||||
try:
|
||||
# --- Per-Fold Data Preparation ---
|
||||
logger.info("Preparing data loaders for the fold...")
|
||||
# Keep scaler and input_size references returned by prepare_fold_data_and_loaders
|
||||
train_loader, val_loader, test_loader, target_scaler_fold, input_size = prepare_fold_data_and_loaders( # Renamed target_scaler
|
||||
# Assume prepare_fold_data_and_loaders returns the data_scaler as the 5th element
|
||||
# Modify this call based on the actual return signature of prepare_fold_data_and_loaders
|
||||
train_loader, val_loader, test_loader, target_scaler_fold, data_scaler_fold, input_size = prepare_fold_data_and_loaders(
|
||||
full_df=full_df,
|
||||
train_idx=train_idx,
|
||||
val_idx=val_idx,
|
||||
@ -114,13 +125,17 @@ def run_single_fold(
|
||||
train_config=config.training,
|
||||
eval_config=config.evaluation
|
||||
)
|
||||
target_scaler = target_scaler_fold # Store the scaler in the outer scope
|
||||
target_scaler = target_scaler_fold # Store the target scaler in the outer scope
|
||||
data_scaler = data_scaler_fold # Store the data scaler in the outer scope
|
||||
logger.info(f"Data loaders prepared. Input size determined: {input_size}")
|
||||
|
||||
# Save necessary items for potential later use (e.g., ensemble)
|
||||
# Save necessary items for potential later use (e.g., ensemble, inference)
|
||||
# Capture the paths when saving
|
||||
saved_target_scaler_path = fold_output_dir / "target_scaler.pt"
|
||||
torch.save(target_scaler, saved_target_scaler_path)
|
||||
saved_data_scaler_path = fold_output_dir / "data_scaler.pt"
|
||||
torch.save(data_scaler, saved_data_scaler_path)
|
||||
|
||||
torch.save(test_loader, fold_output_dir / "test_loader.pt") # Test loader might be large, consider if needed
|
||||
|
||||
# Save input size and capture path
|
||||
@ -140,13 +155,14 @@ def run_single_fold(
|
||||
model_config=config.model,
|
||||
train_config=config.training,
|
||||
input_size=input_size,
|
||||
target_scaler=target_scaler_fold # Pass scaler during init
|
||||
target_scaler=target_scaler_fold,
|
||||
data_scaler=data_scaler
|
||||
)
|
||||
logger.info("LSTMForecastLightningModule initialized.")
|
||||
|
||||
# --- PyTorch Lightning Callbacks ---
|
||||
# Ensure monitor_metric matches the exact name logged in model.py
|
||||
monitor_metric = "val_MeanAbsoluteError_Original_Scale" # Corrected metric name
|
||||
monitor_metric = "val_MeanAbsoluteError" # Corrected metric name
|
||||
monitor_mode = "min"
|
||||
|
||||
early_stop_callback = None
|
||||
@ -174,6 +190,7 @@ def run_single_fold(
|
||||
|
||||
callbacks = [checkpoint_callback, lr_monitor]
|
||||
if early_stop_callback:
|
||||
# noinspection PyTypeChecker
|
||||
callbacks.append(early_stop_callback)
|
||||
|
||||
# --- PyTorch Lightning Logger ---
|
||||
@ -190,12 +207,13 @@ def run_single_fold(
|
||||
|
||||
trainer = pl.Trainer(
|
||||
accelerator=accelerator,
|
||||
check_val_every_n_epoch=config.training.check_val_n_epoch,
|
||||
devices=devices,
|
||||
enable_progress_bar=enable_progress_bar,
|
||||
max_epochs=config.training.epochs,
|
||||
callbacks=callbacks,
|
||||
logger=pl_logger,
|
||||
log_every_n_steps=max(1, len(train_loader)//10),
|
||||
enable_progress_bar=True,
|
||||
gradient_clip_val=getattr(config.training, 'gradient_clip_val', None),
|
||||
precision=precision,
|
||||
)
|
||||
@ -262,59 +280,33 @@ def run_single_fold(
|
||||
logger.info(f"Processing {n_predictions} prediction results for Fold {fold_id}...")
|
||||
|
||||
# --- Calculate Correct Time Index for Plotting (First Horizon) ---
|
||||
prediction_target_time_index_h1 = calculate_h1_target_index(
|
||||
full_df=full_df,
|
||||
test_idx=test_idx,
|
||||
sequence_length=config.features.sequence_length,
|
||||
forecast_horizon=config.features.forecast_horizon,
|
||||
n_predictions=n_predictions,
|
||||
fold_id=fold_id
|
||||
)
|
||||
|
||||
# --- Handle Saving/Cleanup of the Index File ---
|
||||
prediction_target_time_index_h1_path = fold_output_dir / "prediction_target_time_index_h1.pt"
|
||||
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
if test_idx is not None and config.features.forecast_horizon and len(config.features.forecast_horizon) > 0:
|
||||
if prediction_target_time_index_h1 is not None and config.evaluation.save_plots:
|
||||
# Save the calculated index if valid and plots are enabled
|
||||
try:
|
||||
test_block_index = full_df.index[test_idx]
|
||||
seq_len = config.features.sequence_length
|
||||
first_horizon = config.features.forecast_horizon[0]
|
||||
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
|
||||
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
|
||||
except Exception as save_e:
|
||||
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
|
||||
elif prediction_target_time_index_h1_path.exists():
|
||||
# Remove outdated file if index is invalid/not calculated OR plots disabled
|
||||
logger.debug(f"Removing potentially outdated time index file: {prediction_target_time_index_h1_path}")
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
|
||||
target_indices_h1 = test_idx + seq_len + first_horizon - 1
|
||||
|
||||
valid_target_indices_h1_mask = target_indices_h1 < len(full_df)
|
||||
valid_target_indices_h1 = target_indices_h1[valid_target_indices_h1_mask]
|
||||
|
||||
if len(valid_target_indices_h1) >= n_predictions: # Should be exactly n_predictions if no indices were out of bounds
|
||||
prediction_target_time_index_h1 = full_df.index[valid_target_indices_h1[:n_predictions]]
|
||||
if len(prediction_target_time_index_h1) != n_predictions:
|
||||
logger.warning(f"Fold {fold_id}: Calculated target time index length ({len(prediction_target_time_index_h1)}) "
|
||||
f"does not match prediction count ({n_predictions}). Plotting x-axis might be misaligned.")
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Cannot calculate target time index for h1; insufficient valid indices ({len(valid_target_indices_h1)} < {n_predictions}).")
|
||||
prediction_target_time_index_h1 = None
|
||||
|
||||
|
||||
# Save the calculated index if it's valid and evaluation plots are enabled
|
||||
if prediction_target_time_index_h1 is not None and not prediction_target_time_index_h1.empty and config.evaluation.save_plots:
|
||||
try:
|
||||
torch.save(prediction_target_time_index_h1, prediction_target_time_index_h1_path)
|
||||
logger.debug(f"Saved prediction target time index for h1 to {prediction_target_time_index_h1_path}")
|
||||
except Exception as save_e:
|
||||
logger.warning(f"Failed to save prediction target time index file {prediction_target_time_index_h1_path}: {save_e}")
|
||||
|
||||
elif prediction_target_time_index_h1_path.exists():
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
logger.debug("Removed outdated prediction target time index h1 file.")
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {fold_id}: Error calculating or saving target time index for plotting: {e}", exc_info=True)
|
||||
prediction_target_time_index_h1 = None
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Skipping target time index calculation (missing test_idx, forecast_horizon, or empty list).")
|
||||
if prediction_target_time_index_h1_path.exists():
|
||||
try:
|
||||
prediction_target_time_index_h1_path.unlink()
|
||||
logger.debug("Removed outdated prediction target time index h1 file as calculation was skipped.")
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove outdated prediction target index h1 file {prediction_target_time_index_h1_path}: {e}")
|
||||
# --- End Index Calculation and Saving ---
|
||||
|
||||
|
||||
@ -324,6 +316,7 @@ def run_single_fold(
|
||||
y_true_scaled=all_targets_scaled, # Pass the (N, H) array
|
||||
y_pred_scaled=all_preds_scaled, # Pass the (N, H) array
|
||||
target_scaler=target_scaler,
|
||||
data_scaler=data_scaler,
|
||||
eval_config=config.evaluation,
|
||||
fold_num=fold_num, # Pass zero-based index
|
||||
output_dir=str(fold_output_dir),
|
||||
@ -331,7 +324,7 @@ def run_single_fold(
|
||||
# Pass the calculated index for the targets being plotted (h1 reference)
|
||||
prediction_time_index=prediction_target_time_index_h1, # Use the calculated index here (for h1)
|
||||
forecast_horizons=config.features.forecast_horizon, # Pass the list of horizons
|
||||
plot_title_prefix=f"CV Fold {fold_id}"
|
||||
plot_title_prefix=f"CV Fold {fold_id}",
|
||||
)
|
||||
save_results(fold_metrics, fold_output_dir / "test_metrics.json")
|
||||
else:
|
||||
@ -376,31 +369,14 @@ def run_single_fold(
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing prediction results for Fold {fold_id}: {e}", exc_info=True)
|
||||
|
||||
# --- Plot Loss Curve for Fold ---
|
||||
try:
|
||||
actual_log_dir = Path(pl_logger.log_dir) / pl_logger.name # Should be .../fold_XX/training_logs
|
||||
metrics_file_path = actual_log_dir / "metrics.csv"
|
||||
|
||||
if metrics_file_path.is_file():
|
||||
plot_loss_curve_from_csv(
|
||||
metrics_csv_path=metrics_file_path,
|
||||
output_path=fold_output_dir / "plots" / "loss_curve.png", # Save in plots subdir
|
||||
title=f"Fold {fold_id} Training Progression",
|
||||
train_loss_col='train_loss',
|
||||
val_loss_col='val_loss' # This function handles fallback
|
||||
)
|
||||
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
|
||||
# --- End Loss Curve Plotting ---
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during Fold {fold_id} pipeline: {e}", exc_info=True)
|
||||
# Ensure paths are None if an error occurs before they are set
|
||||
if saved_model_path is None: saved_model_path = None
|
||||
if saved_target_scaler_path is None: saved_target_scaler_path = None
|
||||
if saved_data_scaler_path is None: saved_data_scaler_path = None # Added check
|
||||
if saved_input_size_path is None: saved_input_size_path = None
|
||||
if saved_config_path is None: saved_config_path = None
|
||||
|
||||
|
||||
finally:
|
||||
@ -413,11 +389,39 @@ def run_single_fold(
|
||||
# Delete loaders explicitly if they might hold references
|
||||
del train_loader, val_loader, test_loader
|
||||
|
||||
# --- Plot Loss Curve for Fold ---
|
||||
if pl_logger and hasattr(pl_logger, 'log_dir') and pl_logger.log_dir: # Check if logger exists and has log_dir
|
||||
try:
|
||||
# Use the logger's log_dir directly, it already includes the 'name' segment
|
||||
actual_log_dir = Path(pl_logger.log_dir) # FIX: Remove appending pl_logger.name
|
||||
metrics_file_path = actual_log_dir / "metrics.csv"
|
||||
|
||||
if metrics_file_path.is_file():
|
||||
plot_loss_curve_from_csv(
|
||||
metrics_csv_path=metrics_file_path,
|
||||
# Save plot inside the specific fold's plot directory
|
||||
output_path=fold_output_dir / "plots" / "loss_curve.png",
|
||||
title=f"Fold {fold_id} Training Progression",
|
||||
train_loss_col='train_loss', # Ensure these column names match your CSVLogger output
|
||||
val_loss_col='val_loss' # Ensure these column names match your CSVLogger output
|
||||
)
|
||||
logger.info(f"Loss curve plot saved for Fold {fold_id} to {fold_output_dir / 'plots' / 'loss_curve.png'}.")
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Could not find metrics.csv at {metrics_file_path} for loss curve plot.")
|
||||
except AttributeError:
|
||||
logger.warning(f"Fold {fold_id}: Could not plot loss curve, CSVLogger object or log_dir attribute missing.")
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {fold_id}: Failed to generate loss curve plot: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning(f"Fold {fold_id}: Skipping loss curve plot generation as CSVLogger was not properly initialized or log_dir is missing.")
|
||||
# --- End Loss Curve Plotting ---
|
||||
|
||||
fold_end_time = time.perf_counter()
|
||||
logger.info(f"--- Finished Fold {fold_id} in {fold_end_time - fold_start_time:.2f} seconds ---")
|
||||
pass
|
||||
|
||||
# Return the calculated fold metrics, best validation score, and saved artifact paths
|
||||
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_input_size_path, saved_config_path
|
||||
return fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, saved_input_size_path, saved_config_path
|
||||
|
||||
|
||||
# --- Main Training & Evaluation Function ---
|
||||
@ -450,8 +454,8 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
|
||||
sys.exit(1)
|
||||
|
||||
for fold_num, (train_idx, val_idx, test_idx) in enumerate(cv_splitter.split()):
|
||||
# Unpack the two new return values from run_single_fold
|
||||
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, _input_size_path, _config_path = run_single_fold(
|
||||
# Unpack the return values from run_single_fold, including the new data_scaler path
|
||||
fold_metrics, best_val_score, saved_model_path, saved_target_scaler_path, saved_data_scaler_path, _input_size_path, _config_path = run_single_fold(
|
||||
fold_num=fold_num,
|
||||
train_idx=train_idx,
|
||||
val_idx=val_idx,
|
||||
@ -503,7 +507,7 @@ def run_training_pipeline(config: MainConfig, output_base_dir: Path):
|
||||
classic_output_dir = output_base_dir / "classic_run" # Define dir for logging path
|
||||
try:
|
||||
# Call the original classic training function directly
|
||||
classic_metrics = run_classic_training(
|
||||
classic_metrics = run_model_training(
|
||||
config=config,
|
||||
full_df=df,
|
||||
output_base_dir=output_base_dir # It creates classic_run subdir internally
|
||||
@ -597,7 +601,7 @@ def run():
|
||||
|
||||
# --- Configuration Loading ---
|
||||
try:
|
||||
config = load_config(config_path)
|
||||
config = load_config(config_path, MainConfig)
|
||||
except Exception:
|
||||
# Error already logged in load_config
|
||||
sys.exit(1)
|
||||
@ -629,4 +633,7 @@ def run():
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
raise DeprecationWarning(
|
||||
"This was the intial class for training, is not maintained!\n Exiting...."
|
||||
)
|
||||
exit(-9999)
|
Reference in New Issue
Block a user