"""Training entry point: build datamodule, trainer and model from parsed args and fit."""
from argparse import Namespace

import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from ml_lib.utils.config import parse_comandline_args_add_defaults
from ml_lib.utils.loggers import Logger
from ml_lib.utils.tools import locate_and_import_class, auto_cast

import variables as v

# Silence Future-/UserWarnings for the whole process.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def run_lightning_loop(h_params: Namespace, data_class: type, model_class: type) -> None:
    """Run one full training session and save the final weights.

    Everything happens inside the ``Logger`` context manager: the checkpoint
    and learning-rate callbacks are created, the datamodule, trainer and model
    are each built via their ``from_argparse_args`` constructor, the model is
    fitted, and a final checkpoint named ``weights.ckpt`` is written below
    ``logger.save_dir``.

    Args:
        h_params: Parsed hyperparameters; handed unchanged to the logger,
            the datamodule, the trainer and the model.
        data_class: Datamodule class. Must provide ``from_argparse_args``;
            the resulting instance must provide ``setup()`` and ``shape``.
        model_class: Model class. Must provide ``from_argparse_args``; the
            resulting instance must provide ``init_weights()``.
    """
    with Logger.from_argparse_args(h_params) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint saving: keep the top 3 checkpoints ranked by 'mean_loss',
        # stored in the logger's log directory.
        ckpt_callback = ModelCheckpoint(
            monitor='mean_loss',
            dirpath=str(logger.log_dir),
            filename='ckpt_weights',
            verbose=False,
            save_top_k=3,
        )

        # Learning-rate logger, sampled once per epoch.
        lr_logger = LearningRateMonitor(logging_interval='epoch')

        # START
        # =============================================================================
        # Let the datamodule pull what it wants.
        datamodule = data_class.from_argparse_args(h_params)
        # Called explicitly before the model is built, since `datamodule.shape`
        # is read below (presumably only available after setup -- confirm).
        datamodule.setup()

        # Let the trainer pull what it wants and add the callbacks.
        trainer = Trainer.from_argparse_args(
            h_params,
            logger=logger,
            callbacks=[ckpt_callback, lr_logger],
        )

        # Let the model pull what it wants.
        model = model_class.from_argparse_args(
            h_params,
            in_shape=datamodule.shape,
            n_classes=v.N_CLASS_multi,
        )
        model.init_weights()

        trainer.fit(model, datamodule)

        # NOTE: the total parameter count
        # (sum(p.numel() for p in model.parameters())) is not logged at the
        # moment; the unused computation has been dropped.

        # Final checkpoint next to the logs; relies on `logger.save_dir`
        # supporting the `/` operator (e.g. a pathlib.Path).
        trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')


if __name__ == '__main__':
    # Parse command-line args, read the config and get the data/model classes.
    cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')

    # To Namespace
    hparams = Namespace(**cmd_args)

    # Start
    # -----------------
    run_lightning_loop(hparams, found_data_class, found_model_class)
    print('done')