Transformer Implementation

Si11ium 2020-10-29 16:40:43 +01:00
parent f296ba78b9
commit 13812b83b5
5 changed files with 167 additions and 66 deletions

View File

@@ -52,7 +52,7 @@ class NormalizeLocal(object):
         std = x.std() + 0.0001
         # Pytorch Version:
-        # x = x.__sub__(mean).__div__(std)
+        # tensor = tensor.__sub__(mean).__div__(std)
         # Numpy Version
         x = (x - mean) / std
         x[np.isnan(x)] = 0

View File

@@ -1,3 +1,5 @@
+import math
 from pathlib import Path
 from typing import Union
@@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):
         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)
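
The disabled branches now use F_x instead of a bare lambda, so the no-op stays a proper nn.Module (a lambda is invisible to module traversal and to the ShapeMixin shape probe). A minimal sketch of an identity module of that kind, under the assumption that F_x does nothing more than return its input while remembering in_shape:

    from torch import nn

    class IdentityModule(nn.Module):
        """Identity stand-in for a disabled layer; mirrors what F_x is assumed to do."""
        def __init__(self, in_shape):
            super(IdentityModule, self).__init__()
            self.in_shape = in_shape

        def forward(self, x):
            return x
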
@@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
         if norm:
-            self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+            norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
+            self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
             self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -181,8 +183,9 @@ class ResidualModule(ShapeMixin, nn.Module):
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

     def forward(self, x):
+        tensor = self.norm(x)
         for module in self.residual_block:
-            tensor = module(x)
+            tensor = module(tensor)

         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
@@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
+
+
+class AttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+        self.dropout = dropout
+        self.features = features
+        raise NotImplementedError
+
+    def forward(self, x):
+        pass
+
+
+class MultiHeadAttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, heads, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+        self.features = features
+        self.heads = heads
+        self.final_dim = self.features // self.heads  # per-head feature size
+
+        self.linear_q = LinearModule(self.features, self.features)
+        self.linear_v = LinearModule(self.features, self.features)
+        self.linear_k = LinearModule(self.features, self.features)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
+        self.linear_out = nn.Linear(self.features, self.features)
+
+    def forward(self, q, k, v, mask=None):
+        batch_size = q.size(0)
+
+        # perform linear operation and split into h heads
+        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
+        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
+        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
+
+        # transpose to get dimensions bs * h * sl * features
+        # ToDo: Do we need this?
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # calculate attention
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            scores = scores.masked_fill(mask == 0, -1e9)
+        scores = F.softmax(scores, dim=-1)
+        scores = self.dropout(scores)
+        scores = torch.matmul(scores, v)
+
+        # concatenate heads and apply final linear transformation
+        # ToDo: This seems to be old coding style. Do we Need this?
+        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
+        output = self.linear_out(concat)
+        return output
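
The forward pass above is standard scaled dot-product attention. Stripped of the repository's wrapper classes, the core computation looks like this on plain tensors (the shapes are illustrative only):

    import math
    import torch
    from torch.nn import functional as F

    # (batch, heads, seq_len, head_dim), illustrative sizes
    q = torch.randn(4, 8, 10, 16)
    k = torch.randn(4, 8, 10, 16)
    v = torch.randn(4, 8, 10, 16)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(16)  # (4, 8, 10, 10)
    weights = F.softmax(scores, dim=-1)                            # each row sums to 1
    attended = torch.matmul(weights, v)                            # (4, 8, 10, 16)
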
+
+
+class TransformerModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+        super(TransformerModule, self).__init__()
+
+        self.in_shape = in_shape
+
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+
+        # Fall back to sane defaults; nn.TransformerEncoderLayer rejects None for dropout and activation.
+        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
+                                                   dropout=dropout if dropout is not None else 0.0,
+                                                   activation=kwargs.get('activation', 'relu')
+                                                   )
+        self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+    def forward(self, x, mask=None, key_padding_mask=None):
+        tensor = self.flat(x)
+        tensor = self.transformer(tensor, mask, key_padding_mask)
+        return tensor
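
TransformerModule delegates to torch's own encoder stack. For reference, a minimal standalone sketch of that underlying API (sizes are illustrative; without a batch_first option, this generation of nn.TransformerEncoder expects inputs shaped (seq_len, batch, d_model)):

    import torch
    from torch import nn

    layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, dim_feedforward=64,
                                       dropout=0.1, activation='relu')
    encoder = nn.TransformerEncoder(layer, num_layers=2)

    src = torch.randn(10, 4, 32)   # (seq_len, batch, d_model)
    out = encoder(src)             # same shape as the input
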

View File

@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader

-from .blocks import ConvModule, DeConvModule, LinearModule
+from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
 from .util import ShapeMixin, LightningBaseModule, Flatten
@@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
         assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'

         min_max = self._find_min_max(dataloader) if dataloader else [None, None]
-        # assert not any([x is None for x in min_max])
+        # assert not any([tensor is None for tensor in min_max])

         lat_min = torch.as_tensor(lat_min or min_max[0])
         lat_max = lat_max or min_max[1]
@@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):
         # Optional Padding for odd image-sizes
         # Obsolet, cdan be done by autopadding module on incoming tensors
-        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
+        # in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]

         # Parameters
         self.lat_dim = lat_dim
@@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
+
+
+class TransformerEncoder(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(TransformerEncoder, self).__init__()
+        # MultiheadSelfAttention
+        self.msa = MultiHeadAttentionModule()
+
+    def forward(self, x):

View File

@@ -8,90 +8,93 @@ from operator import mul
 from torch import nn
 from torch.nn import functional as F
-import pytorch_lightning as pl

 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters

+try:
+    import pytorch_lightning as pl
+
     class LightningBaseModule(pl.LightningModule, ABC):

         @classmethod
         def name(cls):
             return cls.__name__

         @property
         def shape(self):
             try:
                 x = torch.randn(self.in_shape).unsqueeze(0)
                 output = self(x)
                 return output.shape[1:]
             except Exception as e:
                 print(e)
                 return -1

         def __init__(self, hparams):
             super(LightningBaseModule, self).__init__()

             # Set Parameters
             ################################
             self.hparams = hparams
             self.params = ModelParameters(hparams)

-            # Dataset Loading
-            ################################
-            # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here

         def size(self):
             return self.shape

         def save_to_disk(self, model_path):
             Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
             if not (model_path / 'model_class.obj').exists():
                 with (model_path / 'model_class.obj').open('wb') as f:
                     torch.save(self.__class__, f)
             return True

         @property
         def data_len(self):
             return len(self.dataset.train_dataset)

         @property
         def n_train_batches(self):
             return len(self.train_dataloader())

         def configure_optimizers(self):
             raise NotImplementedError

         def forward(self, *args, **kwargs):
             raise NotImplementedError

         def training_step(self, batch_xy, batch_nb, *args, **kwargs):
             raise NotImplementedError

         def test_step(self, *args, **kwargs):
             raise NotImplementedError

         def test_epoch_end(self, outputs):
             raise NotImplementedError

         def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
             self.apply(weight_initializer)

+    modules = (LightningBaseModule, nn.Module)
+
+except ImportError:
+    modules = (nn.Module, )
+    pass  # Maybe post a hint to install pytorch-lightning.


 class ShapeMixin:

     @property
     def shape(self):
-        assert isinstance(self, (LightningBaseModule, nn.Module))
+        assert isinstance(self, modules)

         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

-        if self.in_shape is not None:
+        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
+        if in_shape is not None:
             try:
                 device = self.device
             except AttributeError:
@@ -99,10 +102,11 @@ class ShapeMixin:
                 device = next(self.parameters()).device
             except StopIteration:
                 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-            x = torch.randn(self.in_shape, device=device)
+            x = torch.randn(in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
+            # noinspection PyCallingNonCallable
             y = self(x)
             if isinstance(y, tuple):
                 shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
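
The shape property probes a module by pushing a stacked random tensor through it, so any class that mixes in ShapeMixin and sets in_shape gets its output shape inferred automatically. A minimal sketch with a hypothetical module (not part of the repository), assuming ShapeMixin is importable as defined above:

    import torch
    from torch import nn

    class TinyConvBlock(ShapeMixin, nn.Module):
        def __init__(self, in_shape):
            super(TinyConvBlock, self).__init__()
            self.in_shape = in_shape  # e.g. (3, 32, 32)
            self.conv = nn.Conv2d(in_shape[0], 8, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    block = TinyConvBlock((3, 32, 32))
    print(block.shape)  # inferred output shape, here (8, 32, 32)
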
@@ -265,7 +269,7 @@ class Splitter(nn.Module):
         self.autopad = AutoPadToShape(self._out_shape)

     def forward(self, x: torch.Tensor):
-        dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
         x = x.transpose(0, dim)
         n_blocks = list()
         for block_idx in range(self.n):

View File

@@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
     # TODO: Do this programmatically; This did not work:
     # Initialize Default Sections as Property
     # for section in self.default_sections:
-    #     self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+    #     self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))

     @property
     def main(self):