Transformer running

Steffen Illium 2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

View File

@@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.config import Config
-from ml_lib.utils.logging import Logger
+from ml_lib.utils.loggers import Logger
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)

additions/__init__.py Normal file, 0 additions
View File

additions/losses.py Normal file, 43 additions
View File

@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        ce_loss = F.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        return focal_loss


class FocalLossRob(nn.Module):
    # taken from https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, x, target):
        x = x.clamp(1e-7, 1. - 1e-7)  # own addition
        p_t = torch.where(target == 1, x, 1 - x)
        fl = - 1 * (1 - p_t) ** self.gamma * torch.log(p_t)
        fl = torch.where(target == 1, fl * self.alpha, fl)
        return self._reduce(fl)

    def _reduce(self, x):
        if self.reduction == 'mean':
            return x.mean()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x
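
For orientation, a minimal usage sketch of the two losses added above; the tensor shapes, the additions.losses import path and the sigmoid/one-hot preprocessing for FocalLossRob are illustrative assumptions, not part of the commit:

import torch
import torch.nn.functional as F

from additions.losses import FocalLoss, FocalLossRob  # assumes the repo root is on sys.path

logits = torch.randn(8, 5)            # 8 samples, 5 classes (made-up shapes)
targets = torch.randint(0, 5, (8,))   # integer class labels

# FocalLoss wraps F.cross_entropy and down-weights easy examples via (1 - pt) ** gamma
loss = FocalLoss(gamma=2)(logits, targets)

# FocalLossRob expects probabilities and a binary (one-hot) target of the same shape
probs = torch.sigmoid(logits)
one_hot = F.one_hot(targets, num_classes=5).float()
loss_rob = FocalLossRob(alpha=1, gamma=2)(probs, one_hot)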

View File

@@ -60,7 +60,8 @@ class LibrosaAudioToMelDataset(Dataset):
             self.mel_file_path.unlink(missing_ok=True)
         if not self.mel_file_path.exists():
             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
-            raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate)
+            with self.audio_path.open(mode='rb') as audio_file:
+                raw_sample, _ = librosa.core.load(audio_file, sr=self.sampling_rate)
             mel_sample = self._mel_transform(raw_sample)
             with self.mel_file_path.open('wb') as mel_file:
                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)

View File

@@ -22,10 +22,12 @@ class TorchMelDataset(Dataset):
         self.mel_hop_len = int(mel_hop_len)
         self.sub_segment_hop_len = int(sub_segment_hop_len)
         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
-        if self.sub_segment_len and self.sub_segment_hop_len:
+        if self.sub_segment_len and self.sub_segment_hop_len and (self.n - self.sub_segment_len) > 0:
             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
         else:
             self.offsets = [0]
+        if len(self) == 0:
+            print('what happend here')
         self.label = label
         self.transform = transform
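
To see what the added (self.n - self.sub_segment_len) > 0 guard prevents, here is a small worked example with made-up numbers (none of these values come from the repository):

# Illustrative configuration, not taken from the repository
sampling_rate, mel_hop_len, audio_file_len = 16000, 256, 1.0
n = int((sampling_rate / mel_hop_len) * audio_file_len + 1)   # 63 mel frames
sub_segment_len, sub_segment_hop_len = 100, 50

# Old behaviour: range(0, 63 - 100, 50) is empty, so offsets == [] and len(dataset) == 0
# New behaviour: the guard fails, so the whole file becomes a single segment, offsets == [0]
offsets = (list(range(0, n - sub_segment_len, sub_segment_hop_len))
           if (n - sub_segment_len) > 0 else [0])
print(n, offsets)  # 63 [0]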

View File

@@ -2,7 +2,9 @@ from itertools import cycle
 import numpy as np
 import torch
-from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
+from pytorch_lightning.metrics import Recall
+from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
+    recall_score
 
 from ml_lib.metrics._base_score import _BaseScores
 from ml_lib.utils.tools import to_one_hot
@@ -16,20 +18,21 @@ class MultiClassScores(_BaseScores):
         super(MultiClassScores, self).__init__(*args)
         pass
 
-    def __call__(self, outputs):
+    def __call__(self, outputs, class_names=None):
         summary_dict = dict()
+        class_names = class_names or range(self.model.params.n_classes)
         #######################################################################################
         # Additional Score - UAR - ROC - Conf. Matrix - F1
         #######################################################################################
         #
         # INIT
         y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
-        y_true_one_hot = to_one_hot(y_true, self.model.n_classes)
+        y_true_one_hot = to_one_hot(y_true, self.model.params.n_classes)
 
         y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
         y_pred_max = np.argmax(y_pred, axis=1)
 
-        class_names = {val: key for key, val in self.model.dataset.test_dataset.classes.items()}
+        class_names = {val: key for val, key in enumerate(class_names)}
         ######################################################################################
         #
         # F1 SCORE
@@ -38,7 +41,12 @@ class MultiClassScores(_BaseScores):
         macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None,
                                   zero_division=True)
         summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
+        ######################################################################################
+        #
+        # Unweichted Average Recall
+        uar = recall_score(y_true, y_pred_max, labels=[0, 1, 2, 3, 4], average='macro',
+                           sample_weight=None, zero_division='warn')
+        summary_dict.update(dict(uar_score=uar))
         #######################################################################################
         #
         # ROC Curve
@@ -47,7 +55,7 @@ class MultiClassScores(_BaseScores):
         fpr = dict()
         tpr = dict()
         roc_auc = dict()
-        for i in range(self.model.n_classes):
+        for i in range(self.model.params.n_classes):
             fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
             roc_auc[i] = auc(fpr[i], tpr[i])
@@ -56,15 +64,15 @@ class MultiClassScores(_BaseScores):
         roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
 
         # First aggregate all false positive rates
-        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.n_classes)]))
+        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.params.n_classes)]))
 
         # Then interpolate all ROC curves at this points
         mean_tpr = np.zeros_like(all_fpr)
-        for i in range(self.model.n_classes):
+        for i in range(self.model.params.n_classes):
             mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
 
         # Finally average it and compute AUC
-        mean_tpr /= self.model.n_classes
+        mean_tpr /= self.model.params.n_classes
 
         fpr["macro"] = all_fpr
         tpr["macro"] = mean_tpr
@@ -83,7 +91,7 @@ class MultiClassScores(_BaseScores):
         colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                         'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
-        for i, color in zip(range(self.model.n_classes), colors):
+        for i, color in zip(range(self.model.params.n_classes), colors):
             plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')
 
         plt.plot([0, 1], [0, 1], 'k--', lw=2)
@@ -116,9 +124,9 @@ class MultiClassScores(_BaseScores):
         fig1, ax1 = plt.subplots(dpi=96)
         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                               labels=[class_names[key] for key in class_names.keys()],
-                              normalize='all')
+                              normalize='true')
         disp = ConfusionMatrixDisplay(confusion_matrix=cm,
-                                      display_labels=[class_names[i] for i in range(self.model.n_classes)]
+                                      display_labels=[class_names[i] for i in range(self.model.params.n_classes)]
                                       )
         disp.plot(include_values=True, ax=ax1)

View File

@@ -22,8 +22,8 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 ###################
 class LinearModule(ShapeMixin, nn.Module):
 
-    def __init__(self, in_shape, out_features, bias=True, activation=None,
-                 norm=False, dropout: Union[int, float] = 0, **kwargs):
+    def __init__(self, in_shape, out_features, use_bias=True, activation=None,
+                 use_norm=False, dropout: Union[int, float] = 0, **kwargs):
         if list(kwargs.keys()):
             warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         super(LinearModule, self).__init__()
@@ -31,8 +31,8 @@ class LinearModule(ShapeMixin, nn.Module):
         self.in_shape = in_shape
         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
         self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
-        self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
-        self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
+        self.norm = nn.LayerNorm(self.flat.shape) if use_norm else F_x(self.flat.shape)
+        self.linear = nn.Linear(self.flat.shape, out_features, bias=use_bias)
         self.activation = activation() if activation else F_x(self.linear.out_features)
 
     def forward(self, x):
@@ -47,13 +47,14 @@ class LinearModule(ShapeMixin, nn.Module):
 class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
-                 bias=True, norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
+                 bias=True, use_norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
         assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
         assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
-        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
-        if norm and not trainable:
+        if len(kwargs.keys()):
+            warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        if use_norm and not trainable:
             warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
                           'We set it to "eval" mode.\n' +
                           'Keep this in mind if you do a finetunning or retraining step.'
@@ -72,9 +73,9 @@ class ConvModule(ShapeMixin, nn.Module):
         # Modules
         self.activation = activation() or F_x(None)
+        self.norm = nn.LayerNorm(self.in_shape, eps=1e-04) if use_norm else F_x(None)
         self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
@@ -134,7 +135,7 @@ class DeConvModule(ShapeMixin, nn.Module):
     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                  dropout: Union[int, float] = 0, autopad=0,
                  activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
-                 bias=True, norm=False, **kwargs):
+                 bias=True, use_norm=False, **kwargs):
         super(DeConvModule, self).__init__()
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
@@ -146,7 +147,7 @@ class DeConvModule(ShapeMixin, nn.Module):
         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
+        self.norm = nn.LayerNorm(in_channels, eps=1e-04) if use_norm else F_x(self.in_shape)
         self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)
@@ -166,14 +167,13 @@ class DeConvModule(ShapeMixin, nn.Module):
 class ResidualModule(ShapeMixin, nn.Module):
 
-    def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
+    def __init__(self, in_shape, module_class, n, use_norm=False, **module_parameters):
         assert n >= 1
         super(ResidualModule, self).__init__()
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
-        if norm:
-            norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+        if use_norm:
+            self.norm = nn.LayerNorm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
             self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -216,13 +216,14 @@ class RecurrentModule(ShapeMixin, nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, dropout=0.):
+    def __init__(self, dim, hidden_dim, dropout=0., activation=nn.GELU):
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(dim, hidden_dim),
-            nn.GELU(),
+            activation() or F_x(None),
             nn.Dropout(dropout),
             nn.Linear(hidden_dim, dim),
+            activation() or F_x(None),
             nn.Dropout(dropout)
         )
@@ -272,18 +273,20 @@ class Attention(nn.Module):
 class TransformerModule(ShapeMixin, nn.Module):
 
-    def __init__(self, in_shape, depth, heads, mlp_dim, dropout=None, use_norm=False, activation='gelu'):
+    def __init__(self, in_shape, depth, heads, mlp_dim, dropout=None, use_norm=False,
+                 activation=nn.GELU, use_residual=True):
         super(TransformerModule, self).__init__()
 
         self.in_shape = in_shape
+        self.use_residual = use_residual
 
         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
-        self.layers = nn.ModuleList([])
         self.embedding_dim = self.flat.flat_shape
-        self.norm = nn.LayerNorm(self.embedding_dim)
+        self.norm = nn.LayerNorm(self.embedding_dim) if use_norm else F_x(self.embedding_dim)
         self.attns = nn.ModuleList([Attention(self.embedding_dim, heads=heads, dropout=dropout) for _ in range(depth)])
-        self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout) for _ in range(depth)])
+        self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout, activation=activation)
+                                   for _ in range(depth)])
 
     def forward(self, x, mask=None, return_attn_weights=False, **_):
         tensor = self.flat(x)
@@ -297,11 +300,11 @@ class TransformerModule(ShapeMixin, nn.Module):
                 attn_weights.append(attn_weight)
             else:
                 attn_tensor = attn(attn_tensor, mask=mask)
-            tensor = attn_tensor + tensor
+            tensor = tensor + attn_tensor if self.use_residual else attn_tensor
 
             # MLP
             mlp_tensor = self.norm(tensor)
             mlp_tensor = mlp(mlp_tensor)
-            tensor = tensor + mlp_tensor
+            tensor = tensor + mlp_tensor if self.use_residual else mlp_tensor
         return (tensor, attn_weights) if return_attn_weights else tensor
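
To make the control flow of the reworked forward pass easier to follow, here is a self-contained paraphrase of one depth step. It is written for this note, not copied from the module: the built-in nn.MultiheadAttention stands in for the repo's Attention class, and applying the norm to the attention input is an assumption, since that part of the loop lies outside the hunk shown above.

import torch
import torch.nn as nn


class MiniTransformerBlock(nn.Module):
    """One attention + MLP step mirroring the optional norm/residual logic above (illustrative only)."""

    def __init__(self, dim, heads, mlp_dim, dropout=0.1, use_norm=True, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.norm = nn.LayerNorm(dim) if use_norm else nn.Identity()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
                                 nn.Linear(mlp_dim, dim), nn.GELU(), nn.Dropout(dropout))

    def forward(self, tensor):
        attn_in = self.norm(tensor)                                     # shared norm module, applied before attention
        attn_out, _ = self.attn(attn_in, attn_in, attn_in)
        tensor = tensor + attn_out if self.use_residual else attn_out   # optional residual, new in this commit
        mlp_out = self.mlp(self.norm(tensor))
        tensor = tensor + mlp_out if self.use_residual else mlp_out
        return tensor


x = torch.randn(4, 32, 128)                                             # (batch, seq, dim), made-up shapes
print(MiniTransformerBlock(128, 8, 256)(x).shape)                       # torch.Size([4, 32, 128])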

View File

@@ -183,10 +183,11 @@ class BaseCNNEncoder(ShapeMixin, nn.Module):
     # noinspection PyUnresolvedReferences
     def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                  latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
-                 filters: List[int] = None, kernels: List[int] = None, **kwargs):
+                 filters: List[int] = None, kernels: Union[List[int], int, None] = None, **kwargs):
         super(BaseCNNEncoder, self).__init__()
         assert filters, '"Filters" has to be a list of int'
-        assert kernels, '"Kernels" has to be a list of int'
+        kernels = kernels or [3] * len(filters)
+        kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
         assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
 
         # Optional Padding for odd image-sizes

View File

@@ -1,7 +1,5 @@
-import inspect
-from argparse import ArgumentParser
 from functools import reduce
+from matplotlib import pyplot as plt
 
 from abc import ABC
 from pathlib import Path
@@ -12,14 +10,77 @@ from pytorch_lightning.utilities import argparse_utils
 from torch import nn
 from torch.nn import functional as F, Unfold
+from sklearn.metrics import ConfusionMatrixDisplay
 
 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters
-from ..utils.tools import locate_and_import_class, add_argparse_args
+from ..utils.tools import add_argparse_args
 
 try:
     import pytorch_lightning as pl
 
+    class PLMetrics(pl.metrics.Metric):
+
+        def __init__(self, n_classes, tag=''):
+            super(PLMetrics, self).__init__()
+
+            self.n_classes = n_classes
+            self.tag = tag
+
+            self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False)
+            self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False)
+            self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False)
+            self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
+            # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
+            # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
+            # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
+            self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
+            self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
+
+        def __iter__(self):
+            return iter(((name, metric) for name, metric in self._modules.items()))
+
+        def update(self, preds, target) -> None:
+            for _, metric in self:
+                metric.update(preds, target)
+
+        def reset(self) -> None:
+            for _, metric in self:
+                metric.reset()
+
+        def compute(self) -> dict:
+            tag = f'{self.tag}_' if self.tag else ''
+            return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self}
+
+        def compute_and_prepare(self):
+            pl_metrics = self.compute()
+            images_from_metrics = dict()
+            for metric_name in list(pl_metrics.keys()):
+                if 'curve' in metric_name:
+                    continue
+                    roc_curve = pl_metrics.pop(metric_name)
+                    print('debug_point')
+                elif 'matrix' in metric_name:
+                    matrix = pl_metrics.pop(metric_name)
+                    fig1, ax1 = plt.subplots(dpi=96)
+                    disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(),
+                                                  display_labels=[i for i in range(self.n_classes)]
+                                                  )
+                    disp.plot(include_values=True, ax=ax1)
+                    images_from_metrics[metric_name] = fig1
+                elif 'ROC' in metric_name:
+                    continue
+                    roc = pl_metrics.pop(metric_name)
+                    print('debug_point')
+                else:
+                    pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item()
+
+            return pl_metrics, images_from_metrics
+
     class LightningBaseModule(pl.LightningModule, ABC):
 
         @classmethod
@@ -49,6 +110,9 @@ try:
             self._weight_init = weight_init
             self.params = ModelParameters(model_parameters)
+            self.metrics = PLMetrics(self.params.n_classes, tag='PL')
+            pass
 
         def size(self):
             return self.shape
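
A minimal sketch of how this metric bundle would typically be driven from a LightningModule; the model, the hook bodies and the self.log_dict call are assumptions about the surrounding training code, which is not part of this hunk. PLMetrics is the class added above:

import torch
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    """Skeleton only - shows where PLMetrics hooks in, not a full model."""

    def __init__(self, n_classes=5):
        super().__init__()
        self.net = torch.nn.LazyLinear(n_classes)             # placeholder network
        self.metrics = PLMetrics(n_classes, tag='PL')

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.net(x).softmax(dim=-1)
        self.metrics.update(preds, y)                         # every wrapped metric accumulates state

    def validation_epoch_end(self, outputs):
        scores, figures = self.metrics.compute_and_prepare()  # scalar dict plus confusion-matrix figure
        self.log_dict(scores)                                 # assumed logging call
        # figures['PL_confusion_matrix_score'] could be passed on to Logger.log_image further below
        self.metrics.reset()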

View File

@@ -25,5 +25,12 @@ class _BaseDataModule(LightningDataModule):
         self.datasets = dict()
 
     def transfer_batch_to_device(self, batch, device):
-        return batch.to(device)
+        if isinstance(batch, list):
+            for idx, item in enumerate(batch):
+                try:
+                    batch[idx] = item.to(device)
+                except (AttributeError, RuntimeError):
+                    continue
+            return batch
+        else:
+            return batch.to(device)

View File

@@ -1,6 +1,9 @@
 import ast
+import configparser
 from pathlib import Path
+from typing import Mapping, Dict
+import torch
 
 from copy import deepcopy
 from abc import ABC
@@ -9,8 +12,67 @@ from argparse import Namespace, ArgumentParser
 from collections import defaultdict
 from configparser import ConfigParser, DuplicateSectionError
 import hashlib
 
-from ml_lib.utils.tools import locate_and_import_class
+from pytorch_lightning import Trainer
+
+from ml_lib.utils.loggers import Logger
+from ml_lib.utils.tools import locate_and_import_class, auto_cast
+
+
+# Argument Parser and default Values
+# =============================================================================
+def parse_comandline_args_add_defaults(filepath, overrides=None):
+
+    # Parse Command Line
+    parser = ArgumentParser()
+    parser.add_argument('--model_name', type=str)
+    parser.add_argument('--data_name', type=str)
+
+    # Load Defaults from _parameters.ini file
+    config = configparser.ConfigParser()
+    config.read(str(filepath))
+
+    new_defaults = dict()
+    for key in ['project', 'train', 'data']:
+        defaults = config[key]
+        new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
+
+    if new_defaults['debug']:
+        new_defaults.update(
+            max_epochs=2,
+            max_steps=2  # The seems to be the new "fast_dev_run"
+        )
+
+    args, _ = parser.parse_known_args()
+    overrides = overrides or dict()
+    default_data = overrides.get('data_name', None) or new_defaults['data_name']
+    default_model = overrides.get('model_name', None) or new_defaults['model_name']
+
+    data_name = args.__dict__.get('data_name', None) or default_data
+    model_name = args.__dict__.get('model_name', None) or default_model
+
+    new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
+
+    found_data_class = locate_and_import_class(data_name, 'datasets')
+    found_model_class = locate_and_import_class(model_name, 'models')
+
+    for module in [Logger, Trainer, found_data_class, found_model_class]:
+        parser = module.add_argparse_args(parser)
+
+    args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
+
+    args = vars(args)
+    args.update({key: auto_cast(val) for key, val in args.items()})
+    args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
+                row_log_interval=1000,    # TODO: Better Value / Setting
+                log_save_interval=10000,  # TODO: Better Value / Setting
+                auto_lr_find=not args['debug'],
+                weights_summary='top',
+                check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
+                )
+
+    if overrides is not None and isinstance(overrides, (Mapping, Dict)):
+        args.update(**overrides)
+
+    return args, found_data_class, found_model_class
+
+
 def is_jsonable(x):
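
A hedged sketch of how parse_comandline_args_add_defaults would be wired into a training entry point. The module path, the ini file name and the constructor calls are assumptions for illustration; only the function itself is part of this commit:

from pathlib import Path

from ml_lib.utils.config import parse_comandline_args_add_defaults  # assumed module path

args, data_class, model_class = parse_comandline_args_add_defaults(
    Path('_parameters.ini'),        # ini file with [project], [train], [data] and one section per model
    overrides=dict(debug=True),     # optional dict that wins over ini values and CLI flags
)
datamodule = data_class(**args)     # classes are resolved by name via locate_and_import_class
model = model_class(args)           # exact constructor signatures depend on the surrounding repo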

utils/equal_sampler.py Normal file, 30 additions
View File

@@ -0,0 +1,30 @@
import random
from typing import Iterator, Sequence

from torch.utils.data import Sampler
from torch.utils.data.sampler import T_co


# noinspection PyMissingConstructor
class EqualSampler(Sampler):

    def __init__(self, idxs_per_class: Sequence[Sequence[float]], replacement: bool = True) -> None:
        self.replacement = replacement
        self.idxs_per_class = idxs_per_class
        self.len_largest_class = max([len(x) for x in self.idxs_per_class])

    def __iter__(self) -> Iterator[T_co]:
        return iter(random.choice(self.idxs_per_class[random.randint(0, len(self.idxs_per_class) - 1)])
                    for _ in range(len(self)))

    def __len__(self):
        return self.len_largest_class * len(self.idxs_per_class)


if __name__ == '__main__':
    es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
    for i in es:
        print(i)
    pass
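
The __main__ block above only prints sampled indices; in training the sampler would typically be handed to a DataLoader, roughly like this (the toy dataset, its label layout and the import path of EqualSampler are invented for the example):

import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.equal_sampler import EqualSampler  # assumed import path

# Toy dataset: 5 samples of class 0, 5 of class 1, 2 of class 2
features = torch.randn(12, 3)
labels = torch.tensor([0] * 5 + [1] * 5 + [2] * 2)
dataset = TensorDataset(features, labels)

idxs_per_class = [(labels == c).nonzero(as_tuple=True)[0].tolist() for c in range(3)]
sampler = EqualSampler(idxs_per_class)        # every class is drawn with equal probability

loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for x, y in loader:
    print(y)                                  # class frequencies come out roughly balanced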

View File

@@ -1,5 +1,6 @@
-import inspect
-from argparse import ArgumentParser
+from copy import deepcopy
+import hashlib
 
 from pathlib import Path
 import os
@@ -17,11 +18,34 @@ class Logger(LightningLoggerBase):
     @classmethod
     def from_argparse_args(cls, args, **kwargs):
-        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+        cleaned_args = deepcopy(args.__dict__)
+
+        # Clean Seed and other attributes
+        # TODO: Find a better way in cleaning this
+        for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
+                     'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
+                     'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
+            try:
+                del cleaned_args[attr]
+            except KeyError:
+                pass
+
+        kwargs.update(params=cleaned_args)
+        new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
+        return new_logger
 
     @property
-    def name(self) -> str:
-        return self._name
+    def fingerprint(self):
+        h = hashlib.md5()
+        h.update(self._finger_print_string.encode())
+        fingerprint = h.hexdigest()
+        return fingerprint
+
+    @property
+    def name(self):
+        short_name = "".join(c for c in self.model_name if c.isupper())
+        return f'{short_name}_{self.fingerprint}'
 
     media_dir = 'media'
@@ -42,7 +66,12 @@ class Logger(LightningLoggerBase):
     @property
     def project_name(self):
-        return f"{self.owner}/{self.name.replace('_', '-')}"
+        return f"{self.owner}/{self.projeect_root.replace('_', '-')}"
+
+    @property
+    def projeect_root(self):
+        root_path = Path(os.getcwd()).name if not self.debug else 'test'
+        return root_path
 
     @property
     def version(self):
@@ -56,7 +85,7 @@ class Logger(LightningLoggerBase):
     def outpath(self):
         return Path(self.root_out) / self.model_name
 
-    def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
+    def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
             Parameters are displayed in the experiments Parameters section and each key-value pair can be
@@ -71,51 +100,67 @@ class Logger(LightningLoggerBase):
         super(Logger, self).__init__()
 
         self.debug = debug
-        self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
         self.owner = owner if not self.debug else 'testuser'
        self.neptune_key = neptune_key if not self.debug else 'XXX'
         self.root_out = outpath if not self.debug else 'debug_out'
+        self.params = params
 
         self.seed = seed
         self.model_name = model_name
 
+        if self.params:
+            _, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
+            self._finger_print_string = str(fingerprint_tuple)
+        else:
+            self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
+        self.params.update(fingerprint=self.fingerprint)
+
         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
+                                    params=self.params,
                                     api_key=self.neptune_key,
                                     experiment_name=self.name,
+                                    # tags=?,
                                     project_name=self.project_name)
         try:
             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
         except ProjectNotFound as e:
-            print(f'The project "{self.project_name}"')
+            print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
            print(e)
 
         self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
+        if self.params:
+            self.log_hyperparams(self.params)
+
+    def close(self):
+        self.csvlogger.close()
+        self.neptunelogger.close()
+
+    def set_fingerprint_string(self, fingerprint_str):
+        self._finger_print_string = fingerprint_str
+
+    def log_text(self, name, text, **_):
+        # TODO Implement Offline variant.
+        self.neptunelogger.log_text(name, text)
 
     def log_hyperparams(self, params):
         self.neptunelogger.log_hyperparams(params)
         self.csvlogger.log_hyperparams(params)
         pass
 
+    def log_metric(self, metric_name, metric_value, step=None, **kwargs):
+        self.csvlogger.log_metrics(dict(metric_name=metric_value, **kwargs), step=step, **kwargs)
+        self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)
+        pass
+
     def log_metrics(self, metrics, step=None):
         self.neptunelogger.log_metrics(metrics, step=step)
         self.csvlogger.log_metrics(metrics, step=step)
         pass
 
-    def close(self):
-        self.csvlogger.close()
-        self.neptunelogger.close()
-
-    def log_text(self, name, text, **_):
-        # TODO Implement Offline variant.
-        self.neptunelogger.log_text(name, text)
-
-    def log_metric(self, metric_name, metric_value, **kwargs):
-        self.csvlogger.log_metrics(dict(metric_name=metric_value))
-        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
-
-    def log_image(self, name, image, ext='png', **kwargs):
-        step = kwargs.get('step', None)
-        image_name = f'{step}_{name}' if step is not None else name
+    def log_image(self, name, image, ext='png', step=None, **kwargs):
+        image_name = f'{"0" * (4 - len(str(step)))}{step}_{name}' if step is not None else name
         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
         (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
         image.savefig(image_path, bbox_inches='tight', pad_inches=0)

View File

@@ -2,7 +2,7 @@ import importlib
 import inspect
 import pickle
 import shelve
-from argparse import ArgumentParser
+from argparse import ArgumentParser, ArgumentError
 from ast import literal_eval
 from pathlib import Path, PurePath
 from typing import Union
@@ -70,14 +70,17 @@ def add_argparse_args(cls, parent_parser):
     full_arg_spec = inspect.getfullargspec(cls.__init__)
     n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
     for idx, argument in enumerate(full_arg_spec.args):
-        if argument == 'self':
-            continue
-        if idx < n_non_defaults:
-            parser.add_argument(f'--{argument}', type=int)
-        else:
-            argument_type = type(argument)
-            parser.add_argument(f'--{argument}',
-                                type=argument_type,
-                                default=full_arg_spec.defaults[idx - n_non_defaults]
-                                )
+        try:
+            if argument == 'self':
+                continue
+            if idx < n_non_defaults:
+                parser.add_argument(f'--{argument}', type=int)
+            else:
+                argument_type = type(argument)
+                parser.add_argument(f'--{argument}',
+                                    type=argument_type,
+                                    default=full_arg_spec.defaults[idx - n_non_defaults]
+                                    )
+        except ArgumentError:
+            continue
     return parser
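
As a quick illustration of why the new ArgumentError guard matters: when several classes contribute arguments to the same parser, as parse_comandline_args_add_defaults above does for Logger, Trainer, the data class and the model class, a repeated option name used to raise and abort the whole loop. The Encoder/Decoder classes below are invented for the example; only add_argparse_args itself is part of the repository:

from argparse import ArgumentParser

from ml_lib.utils.tools import add_argparse_args  # the function shown above, assumed import path


class Encoder:
    def __init__(self, lat_dim=256, dropout=0.1): ...


class Decoder:
    def __init__(self, lat_dim=256, upscale=2): ...


parser = ArgumentParser()
parser = add_argparse_args(Encoder, parser)
# Both classes declare --lat_dim; without the try/except ArgumentError the second call would
# raise a "conflicting option string" error, now the duplicate is skipped and parsing continues.
parser = add_argparse_args(Decoder, parser)
args = parser.parse_args(['--lat_dim', '128', '--upscale', '4'])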