# point_to_primitive/utils/module_mixins.py
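"""Mixins bundling the optimizer, training, validation, dataset and dataloader
plumbing shared by the concrete LightningBaseModule models in this project."""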

from abc import ABC
from argparse import Namespace
from collections import defaultdict
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import torch
from numpy import interp
from sklearn.metrics import (roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay,
                             f1_score, roc_auc_score)
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose

from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from ml_lib.utils.transforms import ToTensor
from ml_lib.point_toolset.point_io import BatchToData

from .project_config import GlobalVar


class BaseOptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # TODO: Make this globally available.
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        return opt

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                # Swap the running SWA average back into the model weights.
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    # Clear accumulated optimizer state (e.g. Adam moments).
                    opt.state = defaultdict(dict)
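
# Hedged usage note: the optimizer behaviour above is driven entirely by hparams.
# A minimal sketch of the fields this mixin reads (values illustrative, `MyModel`
# hypothetical):
#
#   hparams = Namespace(lr=1e-3, weight_decay=1e-5,
#                       sto_weight_avg=True,     # wrap Adam in torchcontrib SWA
#                       opt_reset_interval=10)   # wipe optimizer state every 10 epochs
#   model = MyModel(hparams)  # any LightningBaseModule using BaseOptimizerMixin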


class BaseTrainMixin:

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log-Likelihood
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()
    # Batch To Data
    batch_to_data = BatchToData()

    def training_step(self, batch_pos_x_y, batch_nb, *_, **__):
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(*batch_pos_x_y)
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))

    def training_epoch_end(self, outputs):
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())
        # Average every logged value whose key contains 'loss' over the epoch.
        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                        for output in outputs]))
                                 for key in keys if 'loss' in key})
        return summary_dict
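
# Worked example (hedged) of the 'loss'-key averaging in training_epoch_end:
#
#   outputs = [dict(loss=torch.tensor(0.9)), dict(loss=torch.tensor(0.7))]
#   # -> dict(log=dict(mean_loss=tensor(0.8000)))
#
# Keys without 'loss' in them (e.g. 'log') are ignored by the comprehension.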


class BaseValMixin:

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log-Likelihood
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()
    # Batch To Data (was missing here; validation_step below relies on it)
    batch_to_data = BatchToData()

    def validation_step(self, batch_pos_x_y, batch_idx, *_, **__):
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(*batch_pos_x_y)
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(val_nll_loss=nll_loss,
                    batch_idx=batch_idx, y=y, batch_y=data.y)
    def validation_epoch_end(self, outputs, *_, **__):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict(log=dict())
        # With multiple dataloaders, `outputs` would be list[list[dict]] and need an
        # extra loop (`for output_idx, output in enumerate(outputs): ...`);
        # with a single dataloader it is list[dict].
        keys = list(outputs[0].keys())
        # Average every value whose key contains 'loss' over all occurrences.
        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                           for output in outputs]))
                                    for key in keys if 'loss' in key}
                                   )

        #######################################################################################
        # Additional Scores - F1 - ROC - Confusion Matrix
        #######################################################################################
        #
        # INIT
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_true_one_hot = to_one_hot(y_true)
        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
        y_pred_max = np.argmax(y_pred, axis=1)
        class_names = {val: key for key, val in GlobalVar.classes.__dict__().items()}
        ######################################################################################
        #
        # F1 SCORE
        micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro',
                                  sample_weight=None, zero_division=1)
        macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro',
                                  sample_weight=None, zero_division=1)
        summary_dict['log'].update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
        #######################################################################################
        #
        # ROC Curve
        # Compute ROC curve and ROC area for each class.
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(len(GlobalVar.classes)):
            fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area.
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # First aggregate all false positive rates ...
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(GlobalVar.classes))]))

        # ... then interpolate all ROC curves at these points ...
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(len(GlobalVar.classes)):
            mean_tpr += interp(all_fpr, fpr[i], tpr[i])

        # ... finally average and compute the macro AUC.
        mean_tpr /= len(GlobalVar.classes)

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot all ROC curves
        plt.figure()
        plt.plot(fpr["micro"], tpr["micro"],
                 label=f'micro ROC ({round(roc_auc["micro"], 2)})',
                 color='deeppink', linestyle=':', linewidth=4)
        plt.plot(fpr["macro"], tpr["macro"],
                 label=f'macro ROC ({round(roc_auc["macro"], 2)})',
                 color='navy', linestyle=':', linewidth=4)

        colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                        'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'])
        for i, color in zip(range(len(GlobalVar.classes)), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")

        self.logger.log_image('ROC', image=plt.gcf(), step=self.current_epoch)
        plt.clf()
        #######################################################################################
        #
        # ROC SCORE
        macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                          average="macro")
        summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr)

        #######################################################################################
        #
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred_max, labels=list(class_names), normalize='all')
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)
        disp.plot(include_values=True)
        self.logger.log_image('Confusion Matrix', image=plt.gcf(), step=self.current_epoch)
        plt.clf()

        return summary_dict
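
# Shape assumptions for the metric block above (hedged, inferred from the code):
#   y_pred : (N, n_classes) class scores (log-probabilities from an NLL-trained model)
#   y_true : (N,) integer labels; to_one_hot(y_true) -> (N, n_classes)
# e.g. with three classes:
#   y_pred = np.log([[.7, .2, .1], [.1, .8, .1]])   # -> argmax(axis=1) == [0, 1]
#   y_true = np.array([0, 1])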


class DatasetMixin:

    def build_dataset(self, dataset_class):
        assert isinstance(self, LightningBaseModule)

        # Dataset
        # =============================================================================
        # Data Augmentations or Utility Transformations
        transforms = Compose([ToTensor()])

        # Dataset
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
                train_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.train,
                                            transforms=transforms
                                            ),
                # VALIDATION DATASET
                val_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.devel,
                                          ),
                # TEST DATASET
                test_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.test,
                                           ),
            )
        )
        return dataset
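
# Note (hedged): only the training split receives `transforms` above; the
# validation/test splits fall back to whatever `dataset_class` defaults to.
# If all splits should share the pipeline, pass it explicitly:
#
#   val_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.devel,
#                             transforms=transforms)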


class BaseDataloadersMixin(ABC):

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        # In case you want to implement bootstrapping:
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=sampler is None, sampler=sampler,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Validation Dataloader
    def val_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
                                    batch_size=self.params.batch_size, num_workers=self.params.worker)
        # Alternatively return [val_dataloader, other_dataloader]; validation_step then
        # receives a dataloader_idx argument.
        return val_dataloader
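
# Hedged end-to-end sketch: a concrete model combines these mixins with
# LightningBaseModule and is expected to expose `self.params` (hparams) and
# `self.dataset`. `PrimitiveNet` and `PointDataset` are hypothetical stand-ins.
#
#   class PrimitiveNet(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
#                      DatasetMixin, BaseDataloadersMixin, LightningBaseModule):
#       def __init__(self, hparams):
#           super().__init__(hparams)
#           self.dataset = self.build_dataset(PointDataset)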