Transformer Implementation

This commit is contained in:
Si11ium 2020-10-29 16:40:43 +01:00
parent f296ba78b9
commit 13812b83b5
5 changed files with 167 additions and 66 deletions

View File

@@ -52,7 +52,7 @@ class NormalizeLocal(object):
std = x.std() + 0.0001
# Pytorch Version:
# x = x.__sub__(mean).__div__(std)
# tensor = tensor.__sub__(mean).__div__(std)
# Numpy Version
x = (x - mean) / std
x[np.isnan(x)] = 0
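For reference, the local normalization touched above boils down to a zero-mean, unit-variance rescaling with a small epsilon and a NaN guard; a minimal standalone sketch with a made-up sample array:

    import numpy as np

    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    mean, std = x.mean(), x.std() + 0.0001   # epsilon keeps the division finite for constant patches
    x = (x - mean) / std
    x[np.isnan(x)] = 0                       # zero out NaNs from degenerate inputs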

View File

@@ -1,3 +1,5 @@
import math
from pathlib import Path
from typing import Union
@@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
padding=self.padding, stride=self.stride)
@@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
self.in_shape = in_shape
module_parameters.update(in_shape=in_shape)
if norm:
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
else:
self.norm = F_x(self.in_shape)
self.activation = module_parameters.get('activation', None)
@@ -181,8 +183,9 @@ class ResidualModule(ShapeMixin, nn.Module):
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
tensor = self.norm(x)
for module in self.residual_block:
tensor = module(x)
tensor = module(tensor)
# noinspection PyUnboundLocalVariable
tensor = tensor + x
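The corrected loop above threads the running tensor through every sub-module before adding the skip connection; a condensed sketch of the same pattern with plain nn.Linear layers (toy sizes, not the repository's ResidualModule):

    import torch
    from torch import nn

    layers = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])
    x = torch.randn(2, 8)
    tensor = x
    for module in layers:
        tensor = module(tensor)   # feed the running result, not the original input
    tensor = tensor + x           # residual skip connection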
@@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor
class AttentionModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, features, dropout=0.1):
super().__init__()
self.in_shape = in_shape
self.dropout = dropout
self.features = features
raise NotImplementedError
def forward(self, x):
pass
class MultiHeadAttentionModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, heads, features, dropout=0.1):
super().__init__()
self.in_shape = in_shape
self.features = features
self.heads = heads
self.final_dim = self.features // self.heads
self.linear_q = LinearModule(self.features, self.features)
self.linear_v = LinearModule(self.features, self.features)
self.linear_k = LinearModule(self.features, self.features)
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
self.linear_out = nn.Linear(self.features, self.features)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# perform linear operation and split into h heads
k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
# transpose to get dimensions bs * h * sl * features
# ToDo: Do we need this?
k = k.transpose(1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
# calculate attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim=-1)
scores = self.dropout(scores)
scores = torch.matmul(scores, v)
# concatenate heads and apply final linear transformation
# ToDo: This seems to be old coding style. Do we need this?
concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
output = self.linear_out(concat)
return output
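A minimal usage sketch for the module above, assuming LinearModule acts like nn.Linear on the last dimension; batch size, sequence length and the head/feature counts are illustrative values only:

    import torch

    features, heads = 64, 4                                      # features must be divisible by heads
    mha = MultiHeadAttentionModule(in_shape=(10, features), heads=heads, features=features)
    x = torch.randn(2, 10, features)                             # (batch, sequence, features)
    out = mha(x, x, x)                                           # self-attention: query = key = value
    print(out.shape)                                             # expected: torch.Size([2, 10, 64])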
class TransformerModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
super(TransformerModule, self).__init__()
self.in_shape = in_shape
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
dropout=dropout if dropout else 0.0,
activation=kwargs.get('activation', 'relu')
)
self.norm = nn.LayerNorm(self.flat_shape) if use_norm else None
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=self.norm)
def forward(self, x, mask=None, key_padding_mask=None):
tensor = self.flat(x)
tensor = self.transformer(tensor, mask, key_padding_mask)
return tensor
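A usage sketch for TransformerModule, under the assumptions that ShapeMixin exposes flat_shape as the flattened feature count (here equal to in_shape, since an int is passed), that F_x acts as an identity mapping, and that inputs follow the (sequence, batch, features) layout nn.TransformerEncoder expects; all concrete sizes are illustrative:

    import torch

    model = TransformerModule(in_shape=16, hidden_size=64, n_heads=4, num_layers=2,
                              dropout=0.1, activation='relu')
    x = torch.randn(20, 8, 16)      # (sequence, batch, features)
    out = model(x)
    print(out.shape)                # expected: torch.Size([20, 8, 16])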

View File

@@ -11,7 +11,7 @@ from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from .blocks import ConvModule, DeConvModule, LinearModule
from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
from .util import ShapeMixin, LightningBaseModule, Flatten
@@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide whether to give min and max or a dataloader, not both.'
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
# assert not any([x is None for x in min_max])
# assert not any([tensor is None for tensor in min_max])
lat_min = torch.as_tensor(lat_min or min_max[0])
lat_max = lat_max or min_max[1]
@@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):
# Optional Padding for odd image-sizes
# Obsolete, can be done by the autopadding module on incoming tensors
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
# in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]
# Parameters
self.lat_dim = lat_dim
@@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor
class TransformerEncoder(ShapeMixin, nn.Module):
def __init__(self, in_shape, heads=4, features=None, dropout=0.1):
super(TransformerEncoder, self).__init__()
self.in_shape = in_shape
# MultiheadSelfAttention; the head/feature defaults are placeholders, the commit leaves them unspecified
features = features if features is not None else in_shape[-1]
self.msa = MultiHeadAttentionModule(in_shape, heads, features, dropout=dropout)
def forward(self, x):
# self-attention: query, key and value are all the incoming tensor
return self.msa(x, x, x)

View File

@@ -8,13 +8,12 @@ from operator import mul
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
# Utility - Modules
###################
from ..utils.model_io import ModelParameters
try:
import pytorch_lightning as pl
class LightningBaseModule(pl.LightningModule, ABC):
@@ -40,10 +39,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
self.hparams = hparams
self.params = ModelParameters(hparams)
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
def size(self):
return self.shape
@@ -81,17 +76,25 @@ class LightningBaseModule(pl.LightningModule, ABC):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
modules = (LightningBaseModule, nn.Module)
except ImportError:
modules = (nn.Module, )
pass # Maybe post a hint to install pytorch-lightning.
class ShapeMixin:
@property
def shape(self):
assert isinstance(self, (LightningBaseModule, nn.Module))
assert isinstance(self, modules)
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
in_shape = self.in_shape if hasattr(self, 'in_shape') else None
if in_shape is not None:
try:
device = self.device
except AttributeError:
@@ -99,10 +102,11 @@ class ShapeMixin:
device = next(self.parameters()).device
except StopIteration:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(self.in_shape, device=device)
x = torch.randn(in_shape, device=device)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
# noinspection PyCallingNonCallable
y = self(x)
if isinstance(y, tuple):
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
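The shape property above infers an output shape by pushing a stacked dummy tensor through the module itself; a condensed, standalone illustration of that probing idea, written against a hypothetical toy module (Probe) rather than the repository's blocks:

    import torch
    from torch import nn

    class Probe(nn.Module):
        def __init__(self, in_shape):
            super().__init__()
            self.in_shape = in_shape
            self.layer = nn.Linear(in_shape, 8)

        @property
        def shape(self):
            x = torch.randn(self.in_shape)
            x = torch.stack((x, x))      # fake batch dimension, needed e.g. for BatchNorm layers
            y = self(x)
            return y.shape[1:] if len(y.shape[1:]) > 1 else y.shape[-1]

        def forward(self, x):
            return self.layer(x)

    print(Probe(4).shape)                # 8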

View File

@@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
# self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
@property
def main(self):