# Imports
# =============================================================================
from pathlib import Path
import warnings

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger

# Project Specific Logger SubClasses

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def fix_all_random_seeds(config_obj):
    # Seed the python, numpy and torch RNGs for reproducible runs
    import numpy as np
    import torch
    import random
    np.random.seed(config_obj.main.seed)
    torch.manual_seed(config_obj.main.seed)
    random.seed(config_obj.main.seed)


def run_lightning_loop(config_obj):
    # Logging
    # =============================================================================
    # Logger
    with Logger(config_obj) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving
        ckpt_callback = ModelCheckpoint(
            monitor='mean_loss',
            filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=False,
            save_top_k=5,
        )
        # Learning Rate Logger
        lr_logger = LearningRateMonitor(logging_interval='epoch')

        # Trainer
        # =============================================================================
        trainer = Trainer(max_epochs=config_obj.train.epochs,
                          weights_save_path=logger.log_dir,
                          gpus=[0] if torch.cuda.is_available() else None,
                          check_val_every_n_epoch=10,
                          # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                          # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                          # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                          checkpoint_callback=True,
                          callbacks=[lr_logger, ckpt_callback],
                          logger=logger,
                          fast_dev_run=config_obj.main.debug,
                          auto_lr_find=not config_obj.main.debug
                          )

        # Model
        # =============================================================================
        # Build and Init its Weights
        model: LightningBaseModule = config_obj.build_and_init_model()

        # Log parameters
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.log_text('n_parameters', pytorch_total_params)

        # Train It
        if config_obj.model.type.lower() != 'ensemble':
            if not config_obj.main.debug and not config_obj.train.lr:
                # No learning rate given; let the LR finder suggest one
                trainer.tune(model)
                # ToDo: LR Finder Plot
                # fig = lr_finder.plot(suggest=True)

            trainer.fit(model)

        # Save the last state & all parameters
        trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
        model.save_to_disk(logger.log_dir)

        # trainer.run_evaluation(test_mode=True)
    return model


if __name__ == "__main__":
    from _paramters import main_arg_parser

    config = Config.read_argparser(main_arg_parser)
    fix_all_random_seeds(config)
    trained_model = run_lightning_loop(config)