Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

additions/__init__.py (new file, 0 additions)

additions/losses.py (new file, 43 additions)

@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight acts as the alpha parameter that balances class weights

    def forward(self, input, target):
        # Per-sample CE (reduction='none') is required here: reducing first would
        # apply the focal term (1 - pt) ** gamma to one scalar instead of per sample.
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        return focal_loss


class FocalLossRob(nn.Module):
    # taken from https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
    # Binary focal loss; expects probabilities in (0, 1), e.g. sigmoid outputs, not logits.
    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, x, target):
        x = x.clamp(1e-7, 1. - 1e-7)  # own addition: keeps log(p_t) finite
        p_t = torch.where(target == 1, x, 1 - x)
        fl = -1 * (1 - p_t) ** self.gamma * torch.log(p_t)
        fl = torch.where(target == 1, fl * self.alpha, fl)  # alpha weights the positive class
        return self._reduce(fl)

    def _reduce(self, x):
        if self.reduction == 'mean':
            return x.mean()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x
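
For orientation, a minimal usage sketch (not part of the commit; the import path assumes the repository layout above): FocalLoss takes raw logits, since it calls F.cross_entropy internally, while FocalLossRob takes probabilities in (0, 1) together with binary targets.

import torch
from additions.losses import FocalLoss, FocalLossRob

logits = torch.randn(8, 3)                    # raw class scores, batch of 8, 3 classes
labels = torch.randint(0, 3, (8,))            # integer class targets
print(FocalLoss(gamma=2)(logits, labels))     # multi-class variant, expects logits

probs = torch.sigmoid(torch.randn(8, 1))      # probabilities in (0, 1)
binary = torch.randint(0, 2, (8, 1)).float()  # binary 0/1 targets
print(FocalLossRob(alpha=1, gamma=2)(probs, binary))  # binary variant, expects probabilities

As a quick sanity check, with gamma=0 (and alpha=1 for FocalLossRob) both losses reduce to plain cross-entropy, since the focal term (1 - p_t) ** gamma becomes 1.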