# Imports
# =============================================================================
from pathlib import Path

import warnings

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger

# Project Specific Logger SubClasses

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def fix_all_random_seeds(config_obj):
    # Seed every RNG involved (numpy, torch, python) from the single seed in the config.
    import numpy as np
    import torch
    import random
    np.random.seed(config_obj.main.seed)
    torch.manual_seed(config_obj.main.seed)
    random.seed(config_obj.main.seed)


def run_lightning_loop(config_obj):

    # Logging
    # =============================================================================
    # Logger
    with Logger(config_obj) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving
        ckpt_callback = ModelCheckpoint(
            monitor='mean_loss',
            filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=False,
            save_top_k=5,
        )
        # Learning Rate Logger
        lr_logger = LearningRateMonitor(logging_interval='epoch')

        # Trainer
        # =============================================================================
        trainer = Trainer(max_epochs=config_obj.train.epochs,
                          weights_save_path=logger.log_dir,
                          gpus=[0] if torch.cuda.is_available() else None,
                          check_val_every_n_epoch=10,
                          # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                          # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                          # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                          checkpoint_callback=True,
                          callbacks=[lr_logger, ckpt_callback],
                          logger=logger,
                          fast_dev_run=config_obj.main.debug,
                          auto_lr_find=not config_obj.main.debug
                          )

        # Model
        # =============================================================================
        # Build the model and initialize its weights
        model: LightningBaseModule = config_obj.build_and_init_model()

        # Log parameters
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.log_text('n_parameters', pytorch_total_params)

        # Train It
        if config_obj.model.type.lower() != 'ensemble':
            # Run the LR finder unless debugging or a learning rate was set explicitly.
            if not config_obj.main.debug and not config_obj.train.lr:
                trainer.tune(model)
            # ToDo: LR Finder Plot
            # fig = lr_finder.plot(suggest=True)

            trainer.fit(model)

        # Save the last state & all parameters
        trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
        model.save_to_disk(logger.log_dir)

        # Evaluate It
        if config_obj.main.eval:
            # Manual validation pass over the first validation dataloader.
            with torch.no_grad():
                model.eval()
                if torch.cuda.is_available():
                    model.cuda()
                outputs = []
                from tqdm import tqdm
                for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
                    batch_x, label = batch
                    batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
                    label = label.to(device='cuda' if model.on_gpu else 'cpu')
                    outputs.append(
                        model.validation_step((batch_x, label), idx, 1)
                    )
                model.validation_epoch_end([outputs])

            # trainer.test()
            # Write test-set predictions to a CSV next to the other run artifacts.
            outpath = Path(config_obj.train.outpath)
            model_type = config_obj.model.type
            parameters = logger.name
            version = f'version_{logger.version}'
            inference_out = f'{parameters}_test_out.csv'

            from main_inference import prepare_dataloader
            import variables as V

            test_dataloader = prepare_dataloader(config_obj)

            with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
                outfile.write('file_name,prediction\n')
                for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                    batch_x, file_names = batch
                    batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
                    y = model(batch_x).main_out
                    predictions = (y >= 0.5).int()
                    for prediction, file_name in zip(predictions, file_names):
                        prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
                        outfile.write(f'{file_name},{prediction_text}\n')

    return model


if __name__ == "__main__":
    from _paramters import main_arg_parser

    config = Config.read_argparser(main_arg_parser)
    fix_all_random_seeds(config)
    trained_model = run_lightning_loop(config)