from argparse import Namespace
import warnings

import yaml
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from ml_lib.utils.callbacks import BestScoresCallback
from ml_lib.utils.config import parse_comandline_args_add_defaults
from ml_lib.utils.loggers import LightningLogger
import variables as v
from ml_lib.utils.tools import fix_all_random_seeds

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def _collect_callbacks(base_callbacks, additional_callbacks):
    """Return a new list of ``base_callbacks`` extended by ``additional_callbacks``.

    ``additional_callbacks`` may be ``None``/empty, a single ``Callback``, or a
    list/tuple of callbacks. Any other value is ignored, but a ``RuntimeWarning``
    is emitted so the mistake does not go unnoticed.
    """
    callbacks = list(base_callbacks)
    if not additional_callbacks:
        return callbacks
    if isinstance(additional_callbacks, Callback):
        callbacks.append(additional_callbacks)
    elif isinstance(additional_callbacks, (list, tuple)):
        callbacks.extend(additional_callbacks)
    else:
        # RuntimeWarning on purpose: UserWarnings are silenced at module level.
        warnings.warn(
            f'Ignoring additional_callbacks of unsupported type '
            f'{type(additional_callbacks).__name__}; expected a Callback or a list/tuple of Callbacks.',
            RuntimeWarning,
        )
    return callbacks


def _add_data_shape(h_params, datamodule):
    """Return a copy of ``h_params`` extended with ``in_shape`` and ``n_classes`` from ``datamodule``.

    The caller's Namespace is left untouched.
    """
    # Copy: vars() returns the Namespace's own __dict__, so updating it directly
    # would mutate the caller's hyper-parameter object.
    params = dict(vars(h_params))
    try:
        params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
    except KeyError:
        # NOTE(review): presumably `shape` / `n_classes` raise KeyError until the
        # datamodule has loaded its data — confirm in the datamodule implementation.
        datamodule.manual_setup()
        datamodule.prepare_data()
        params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
    return Namespace(**params)


def run_lightning_loop(h_params: Namespace, data_class, model_class, seed=69, additional_callbacks=None):
    """Train and test a model with PyTorch Lightning and return the results.

    Args:
        h_params: Hyper-parameters; the data module, trainer and model each pull
            what they need via ``from_argparse_args``. Not modified.
        data_class: Data-module class exposing ``from_argparse_args``.
        model_class: Model class exposing ``from_argparse_args``.
        seed: Seed passed to ``fix_all_random_seeds``.
        additional_callbacks: Optional ``Callback`` or list/tuple of callbacks
            appended to the default ones.

    Returns:
        Namespace with ``model``, ``best_model_path``, ``best_model_score`` (a
        Python scalar) and ``max_score_monitor`` (best scores tracked by
        ``BestScoresCallback``).
    """
    fix_all_random_seeds(seed)
    with LightningLogger.from_argparse_args(h_params) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint saving, ranked by recall
        ckpt_callback = ModelCheckpoint(
            monitor='PL_recall_score',
            dirpath=str(logger.log_dir),
            filename='ckpt_weights',
            mode='max',
            verbose=False,
            save_top_k=3,
            save_last=True,
        )
        # Learning-rate logger
        lr_logger = LearningRateMonitor(logging_interval='epoch')
        # Track best scores
        score_callback = BestScoresCallback(['PL_recall_score'])

        callbacks = _collect_callbacks([ckpt_callback, lr_logger, score_callback], additional_callbacks)

        # START
        # =============================================================================
        # Let the data module pull what it wants and init
        datamodule = data_class.from_argparse_args(h_params)

        # Final h_params setup (works on a copy of the caller's Namespace)
        h_params = _add_data_shape(h_params, datamodule)

        # Let the Trainer pull what it wants and add callbacks
        trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)

        # Let the model pull what it wants
        model = model_class.from_argparse_args(h_params)
        model.init_weights()

        # Store model in object file
        model.save_to_disk(logger.save_dir)

        # Store h_params to yaml file & Neptune (if available)
        logger.log_hyperparams(h_params)

        trainer.fit(model, datamodule)
        trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')

        # Evaluate the best checkpoint selected by ckpt_callback
        trainer.test(model=model, datamodule=datamodule, ckpt_path='best')

    return Namespace(model=model,
                     best_model_path=ckpt_callback.best_model_path,
                     best_model_score=ckpt_callback.best_model_score.item(),
                     max_score_monitor=score_callback.best_scores)


if __name__ == '__main__':
    # Parse command-line args, read config and resolve the data/model classes
    cmd_args, found_data_class, found_model_class, found_seed = parse_comandline_args_add_defaults('_parameters.ini')

    # To Namespace
    hparams = Namespace(**cmd_args)

    # Start
    # -----------------
    print(run_lightning_loop(hparams, found_data_class, found_model_class, found_seed))
    print('done')