# Imports
# =============================================================================
from pathlib import Path

import warnings

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from tqdm import tqdm

from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.logging import Logger

# Project Specific Logger SubClasses
from util.config import MConfig

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def fix_all_random_seeds(config_obj):
    import numpy as np
    import torch
    import random
    np.random.seed(config_obj.main.seed)
    torch.manual_seed(config_obj.main.seed)
    random.seed(config_obj.main.seed)


def run_lightning_loop(config_obj):
    # Logging
    # =============================================================================
    # Logger
    with Logger(config_obj) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving
        checkpoint_callback = ModelCheckpoint(
            monitor='uar_score',
            filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=False,
            save_top_k=5,
        )
        # Early Stopping
        # TODO: For this to work, set a validation step and end-of-epoch eval and score
        early_stopping_callback = EarlyStopping(
            monitor='uar_score',
            min_delta=0.01,
            patience=10,
        )

        # Trainer
        # =============================================================================
        trainer = Trainer(max_epochs=config_obj.train.epochs,
                          show_progress_bar=True,
                          weights_save_path=logger.log_dir,
                          gpus=[0] if torch.cuda.is_available() else None,
                          check_val_every_n_epoch=10,
                          # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                          # row_log_interval=(model.n_train_batches * 0.1),   # TODO: Better Value / Setting
                          # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                          checkpoint_callback=checkpoint_callback,
                          logger=logger,
                          fast_dev_run=config_obj.main.debug,
                          early_stop_callback=None  # early stopping defined above, not yet wired in (see TODO)
                          )

        # Model
        # =============================================================================
        # Build the model and init its weights
        model: LightningBaseModule = config_obj.build_and_init_model()

        # Log parameter count
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.log_text('n_parameters', pytorch_total_params)

        # Train it
        if config_obj.model.type.lower() != 'ensemble':
            trainer.fit(model)

        # Save the last state & all parameters
        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
        model.save_to_disk(logger.log_dir)

        # Evaluate it
        if config_obj.main.eval:
            with torch.no_grad():
                model.eval()
                if torch.cuda.is_available():
                    model.cuda()

                # Run the validation set manually and aggregate the step outputs
                outputs = []
                for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
                    batch_x, label = batch
                    batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
                    label = label.to(device='cuda' if model.on_gpu else 'cpu')
                    outputs.append(model.validation_step((batch_x, label), idx, 1))
                summary_dict = model.validation_epoch_end([outputs])
                print(summary_dict['log']['uar_score'])

                # Run inference on the test set and write predictions to CSV
                # trainer.test()
                outpath = Path(config_obj.train.outpath)
                model_type = config_obj.model.type
                parameters = logger.name
                version = f'version_{logger.version}'
                inference_out = f'{parameters}_test_out.csv'

                from main_inference import prepare_dataloader
                import variables as V

                test_dataloader = prepare_dataloader(config_obj)

                with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
                    outfile.write('file_name,prediction\n')
                    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                        batch_x, file_name = batch
                        batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
                        y = model(batch_x).main_out
                        prediction = (y.squeeze() >= 0.5).int().item()
                        prediction = 'clear' if prediction == V.CLEAR else 'mask'
                        outfile.write(f'{file_name},{prediction}\n')

    return model


if __name__ == "__main__":
    from _paramters import main_arg_parser

    config = MConfig.read_argparser(main_arg_parser)
    fix_all_random_seeds(config)
    trained_model = run_lightning_loop(config)