from collections import defaultdict
from itertools import cycle

from abc import ABC
from argparse import Namespace

import torch
import numpy as np
from numpy import interp
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay, f1_score, roc_auc_score

import matplotlib.pyplot as plt

from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data, DataLoader
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot

from utils.project_settings import dataSplit


class BaseOptimizerMixin:
    """Optimizer boilerplate for LightningBaseModule subclasses.

    Provides Adam configuration, optional Stochastic Weight Averaging (SWA),
    and periodic optimizer-state resets.
    """

    def configure_optimizers(self):
        """Build an Adam optimizer from ``self.params``; wrap it in SWA when requested.

        Returns:
            The optimizer instance (``Adam`` or ``SWA``-wrapped ``Adam``).
        """
        assert isinstance(self, LightningBaseModule)
        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # TODO: Make this globally available.
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        return opt

    def on_train_end(self):
        """After training, swap the running SWA average into the model weights."""
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        """Reset optimizer state (e.g. Adam moment estimates) every
        ``params.opt_reset_interval`` epochs, when that interval is configured.
        """
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    # Clearing ``state`` restarts the moment estimates from scratch.
                    opt.state = defaultdict(dict)


class BaseTrainMixin:
    """Training-loop boilerplate: NLL loss on the model's main output."""

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log Likelihood (expects log-probabilities)
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()

    def training_step(self, batch_norm_pos_y, batch_nb, *_, **__):
        """Run one training batch; return the NLL loss for backprop.

        ``batch_norm_pos_y`` is converted with ``self.batch_to_data`` unless it
        already is a ``torch_geometric.data.Data`` object.
        """
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(batch_norm_pos_y) if not isinstance(batch_norm_pos_y, Data) else batch_norm_pos_y
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))

    def training_epoch_end(self, outputs):
        """Aggregate every key containing 'loss' into its epoch mean for logging."""
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())
        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                        for output in outputs]))
                                 for key in keys if 'loss' in key})
        return summary_dict


class BaseValMixin:
    """Validation boilerplate: NLL loss plus F1, ROC curves/AUC and a confusion matrix."""

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log Likelihood (expects log-probabilities)
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()

    def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
        """Run one validation batch; keep predictions and targets for the epoch-end metrics."""
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(val_nll_loss=nll_loss,
                    batch_idx=batch_idx, y=y, batch_y=data.y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Aggregate validation losses and compute/log F1, ROC (per class, micro, macro)
        and a normalized confusion matrix.

        Returns:
            dict with a ``log`` sub-dict of scalar metrics for the logger.
        """
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict(log=dict())
        # In case of multiple given dataloaders, outputs would be: list[list[dict]]
        # (one inner list per dataloader); else: list[dict].
        keys = list(outputs[0].keys())
        # Add every value that has a "loss" in it, by calculating the mean over all occurrences.
        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                           for output in outputs]))
                                    for key in keys if 'loss' in key}
                                   )

        #######################################################################################
        #
        # Additional Score - UAR - ROC - Conf. Matrix - F1
        #######################################################################################
        #
        # INIT
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_true_one_hot = to_one_hot(y_true, self.n_classes)

        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
        y_pred_max = np.argmax(y_pred, axis=1)

        # Map class index -> class name (dataset stores name -> index).
        class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}

        ######################################################################################
        #
        # F1 SCORE
        # NOTE: zero_division=1 (documented literal; the previous ``True`` relied on True == 1).
        micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro',
                                  sample_weight=None, zero_division=1)
        macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro',
                                  sample_weight=None, zero_division=1)
        summary_dict['log'].update(dict(micro_f1_score=micro_f1_score,
                                        macro_f1_score=macro_f1_score))

        #######################################################################################
        #
        # ROC Curve

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(self.n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area over all classes flattened together
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))

        # Then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(self.n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

        # Finally average it and compute AUC
        mean_tpr /= self.n_classes

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot all ROC curves
        plt.figure()
        plt.plot(fpr["micro"], tpr["micro"],
                 label=f'micro ROC ({round(roc_auc["micro"], 2)})',
                 color='deeppink', linestyle=':', linewidth=4)

        plt.plot(fpr["macro"], tpr["macro"],
                 label=f'macro ROC({round(roc_auc["macro"], 2)})',
                 color='navy', linestyle=':', linewidth=4)

        colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                        'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )

        for i, color in zip(range(self.n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")

        self.logger.log_image('ROC', image=plt.gcf(), step=self.current_epoch)
        plt.clf()

        #######################################################################################
        #
        # ROC SCORE
        try:
            macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="macro")
            summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr)
        except ValueError:
            # Macro averaging fails e.g. when a class is absent from y_true; fall back to micro.
            micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="micro")
            summary_dict['log'].update(micro_roc_auc_ovr=micro_roc_auc_ovr)

        #######################################################################################
        #
        # Confusion matrix (normalized over all entries)
        cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                              labels=[class_names[key] for key in class_names.keys()],
                              normalize='all')
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)
        disp.plot(include_values=True)

        self.logger.log_image('Confusion Matrix', image=disp.figure_, step=self.current_epoch)
        plt.close('all')

        return summary_dict


class DatasetMixin:
    """Dataset construction boilerplate: builds train/val/test splits of one dataset class."""

    def build_dataset(self, dataset_class, **kwargs):
        """Instantiate the three dataset splits.

        The validation and test splits additionally get the fixed-points +
        normalize-scale transform; the train split is built with
        ``collate_per_segment=True``.

        Returns:
            argparse.Namespace with ``train_dataset``, ``val_dataset``, ``test_dataset``.
        """
        assert isinstance(self, LightningBaseModule)
        assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
                                                               f'Expected was {self.params.dataset_type}, ' + \
                                                               f'given:{dataset_class.name}'

        # Dataset
        # =============================================================================
        # Data Augmentations or Utility Transformations
        transforms = Compose(
            [
                FixedPoints(8096),
                NormalizeScale()
            ]
        )
        test_kwargs = kwargs.copy()
        test_kwargs.update(transform=transforms)

        # Dataset
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
                train_dataset=dataset_class(self.params.root, mode=dataSplit.train,
                                            collate_per_segment=True, **kwargs),

                # VALIDATION DATASET
                val_dataset=dataset_class(self.params.root, mode=dataSplit.devel,
                                          collate_per_segment=False, **test_kwargs),

                # TEST DATASET
                test_dataset=dataset_class(self.params.root, mode=dataSplit.predict,
                                           collate_per_segment=False, **test_kwargs),
            )
        )
        return dataset


class BaseDataloadersMixin(ABC):
    """Dataloader boilerplate for the three dataset splits."""

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        # In case you want to implement bootstrapping:
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        # NOTE(review): with sampler=None this trains WITHOUT shuffling (shuffle=False) --
        # confirm that is intended; usually shuffle=True is wanted when no sampler is given.
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None,
                          sampler=sampler,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Validation Dataloader
    def val_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
                                    batch_size=self.params.batch_size,
                                    num_workers=self.params.worker)
        # Alternative: return [val_dataloader, alternative dataloader];
        # there will then be a dataloader_idx in validation_step.
        return val_dataloader