Transformer Implementation

Si11ium 2020-10-29 16:40:43 +01:00
parent f296ba78b9
commit 13812b83b5
5 changed files with 167 additions and 66 deletions

View File

@ -52,7 +52,7 @@ class NormalizeLocal(object):
std = x.std() + 0.0001
# Pytorch Version:
# x = x.__sub__(mean).__div__(std)
# tensor = tensor.__sub__(mean).__div__(std)
# Numpy Version
x = (x - mean) / std
x[np.isnan(x)] = 0
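
Read in isolation, the transform above is a per-sample z-score normalization with an epsilon guard and NaN clean-up; a minimal standalone sketch of the same logic (the function name is illustrative, not part of the repo):

import numpy as np

def normalize_local(x: np.ndarray, eps: float = 1e-4) -> np.ndarray:
    # Per-sample z-score normalization; eps guards against a zero std.
    mean = x.mean()
    std = x.std() + eps
    x = (x - mean) / std
    # Zero out any NaNs that survive (e.g. if the input already contained NaNs).
    x[np.isnan(x)] = 0
    return x
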

View File

@ -1,3 +1,5 @@
import math
from pathlib import Path
from typing import Union
@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
padding=self.padding, stride=self.stride)
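
F_x itself is not shown in this diff; the point of the change is to replace a bare lambda with a proper identity nn.Module, which keeps the layer visible in the module tree and keeps the model picklable. A minimal sketch of what such an identity module could look like (a hypothetical stand-in, not the repo's F_x):

from torch import nn

class IdentityModule(nn.Module):
    # Hypothetical F_x-style no-op: returns its input unchanged.
    def __init__(self, in_shape=None):
        super().__init__()
        self.in_shape = in_shape  # retained only so downstream shape checks keep working

    def forward(self, x):
        return x
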
@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
self.in_shape = in_shape
module_parameters.update(in_shape=in_shape)
if norm:
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
else:
self.norm = F_x(self.in_shape)
self.activation = module_parameters.get('activation', None)
@ -181,8 +183,9 @@ class ResidualModule(ShapeMixin, nn.Module):
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
tensor = self.norm(x)
for module in self.residual_block:
tensor = module(x)
tensor = module(tensor)  # fixed: feed the running result, not the raw input
# noinspection PyUnboundLocalVariable
tensor = tensor + x
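
The change above fixes the residual chaining: each block consumes the previous block's output rather than the raw input, and the skip connection is added at the end. The intended data flow, written out as a plain function (a sketch, not repo code):

def residual_forward(norm, blocks, x):
    # Normalize once, chain the blocks, then add the skip connection.
    tensor = norm(x)
    for module in blocks:
        tensor = module(tensor)  # pass the running result through each block
    return tensor + x
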
@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor
class AttentionModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, features, dropout=0.1):
super().__init__()
self.in_shape = in_shape
self.dropout = dropout
self.features = features
raise NotImplementedError
def forward(self, x):
pass
class MultiHeadAttentionModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, heads, features, dropout=0.1):
super().__init__()
self.in_shape = in_shape
self.features = features
self.heads = heads
self.final_dim = self.features // self.heads
self.linear_q = LinearModule(self.features, self.features)
self.linear_v = LinearModule(self.features, self.features)
self.linear_k = LinearModule(self.features, self.features)
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
self.linear_out = nn.Linear(self.features, self.features)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# perform linear operation and split into h heads
k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
# transpose to get dimensions bs * h * sl * features
# ToDo: Do we need this?
k = k.transpose(1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
# calculate attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim=-1)
scores = self.dropout(scores)
scores = torch.matmul(scores, v)
# concatenate heads and apply final linear transformation
# ToDo: This seems to be old coding style. Do we need this?
concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
output = self.linear_out(concat)
return output
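
The core of the forward pass above is standard scaled dot-product attention. For reference, that step can be written as a self-contained helper (a sketch assuming q, k, v are already split into heads, i.e. shaped (batch, heads, seq_len, head_dim)):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)
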
class TransformerModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
super(TransformerModule, self).__init__()
self.in_shape = in_shape
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
dropout=dropout if dropout else 0., activation=kwargs.get('activation', 'relu')
)
self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, )
def forward(self, x, mask=None, key_padding_mask=None):
tensor = self.flat(x)
tensor = self.transformer(tensor, mask, key_padding_mask)
return tensor
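
TransformerModule delegates the heavy lifting to torch's built-in encoder stack. The underlying API can be exercised directly as follows (a minimal sketch with made-up sizes; note that nn.TransformerEncoder expects input shaped (seq_len, batch, features) here):

import torch
from torch import nn

d_model, n_heads, hidden_size = 32, 4, 64
layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                   dim_feedforward=hidden_size,
                                   dropout=0.1, activation='relu')
encoder = nn.TransformerEncoder(layer, num_layers=2)

src = torch.randn(16, 8, d_model)  # (seq_len, batch, features)
out = encoder(src)                 # same shape as src
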

View File

@ -11,7 +11,7 @@ from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from .blocks import ConvModule, DeConvModule, LinearModule
from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
from .util import ShapeMixin, LightningBaseModule, Flatten
@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide whether to give min, max or a dataloader, not both.'
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
# assert not any([x is None for x in min_max])
# assert not any([tensor is None for tensor in min_max])
lat_min = torch.as_tensor(lat_min or min_max[0])
lat_max = lat_max or min_max[1]
@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):
# Optional Padding for odd image-sizes
# Obsolete, can be done by autopadding module on incoming tensors
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
# in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]
# Parameters
self.lat_dim = lat_dim
@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor
class TransformerEncoder(ShapeMixin, nn.Module):
def __init__(self, in_shape):
super(TransformerEncoder, self).__init__()
# Multi-head self-attention (WIP stub: constructor arguments are still missing here)
self.msa = MultiHeadAttentionModule()
def forward(self, x):
raise NotImplementedError

View File

@ -8,90 +8,93 @@ from operator import mul
from torch import nn
from torch.nn import functional as F
# Utility - Modules
###################
from ..utils.model_io import ModelParameters
try:
import pytorch_lightning as pl
class LightningBaseModule(pl.LightningModule, ABC):
@classmethod
def name(cls):
return cls.__name__
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
# Set Parameters
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
def size(self):
return self.shape
def save_to_disk(self, model_path):
Path(model_path).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
with (model_path / 'model_class.obj').open('wb') as f:
torch.save(self.__class__, f)
return True
@property
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
modules = (LightningBaseModule, nn.Module)
except ImportError:
modules = (nn.Module, )
pass # Maybe post a hint to install pytorch-lightning.
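
The pattern above makes pytorch-lightning an optional dependency: the tuple of accepted base classes is decided at import time, and the same isinstance check works either way. Stripped down to its essence (names here are illustrative):

from torch import nn

try:
    import pytorch_lightning as pl
    base_classes = (pl.LightningModule, nn.Module)
except ImportError:
    # pytorch-lightning is optional; fall back to plain torch modules.
    base_classes = (nn.Module,)

def is_supported_module(module) -> bool:
    # isinstance accepts a tuple of types, so one check covers both cases.
    return isinstance(module, base_classes)
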
class ShapeMixin:
@property
def shape(self):
assert isinstance(self, (LightningBaseModule, nn.Module))
assert isinstance(self, modules)
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
in_shape = self.in_shape if hasattr(self, 'in_shape') else None
if in_shape is not None:
try:
device = self.device
except AttributeError:
@ -99,10 +102,11 @@ class ShapeMixin:
device = next(self.parameters()).device
except StopIteration:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(self.in_shape, device=device)
x = torch.randn(in_shape, device=device)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
# noinspection PyCallingNonCallable
y = self(x)
if isinstance(y, tuple):
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
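
The shape property above infers a module's output shape empirically: build a dummy input, stack it into a batch of two so BatchNorm layers do not fail on a single sample, and read the shape of the result. The same idea as a standalone helper (a sketch; the helper name is not part of the repo):

import torch
from torch import nn

def probe_out_shape(module: nn.Module, in_shape):
    # Infer the output shape by pushing a dummy batch through the module.
    device = next(module.parameters(), torch.empty(0)).device
    x = torch.randn(in_shape, device=device)
    x = torch.stack((x, x))  # batch of two, so BatchNorm sees more than one sample
    with torch.no_grad():
        y = module(x)
    return y.shape[1:]
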
@ -265,7 +269,7 @@ class Splitter(nn.Module):
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
x = x.transpose(0, dim)
n_blocks = list()
for block_idx in range(self.n):

View File

@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
# self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
@property
def main(self):