Transformer Implementation
parent f296ba78b9
commit 13812b83b5
@@ -52,7 +52,7 @@ class NormalizeLocal(object):
         std = x.std() + 0.0001

         # Pytorch Version:
-        # x = x.__sub__(mean).__div__(std)
+        # tensor = tensor.__sub__(mean).__div__(std)
         # Numpy Version
         x = (x - mean) / std
         x[np.isnan(x)] = 0
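For context, the normalization in this hunk is a plain standardization with a small epsilon and a NaN guard. A minimal NumPy sketch of the same step (function and variable names here are illustrative, not taken from the repository):

import numpy as np

def normalize_local(x: np.ndarray) -> np.ndarray:
    # Zero mean, unit variance; the epsilon mirrors the `+ 0.0001` guard above.
    mean = x.mean()
    std = x.std() + 0.0001
    x = (x - mean) / std
    # Replace any remaining NaNs, e.g. from constant patches.
    x[np.isnan(x)] = 0
    return x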
@@ -1,3 +1,5 @@
+import math
+
 from pathlib import Path
 from typing import Union

@@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):

         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)

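The change above swaps the `lambda x: x` placeholders for `F_x(self.in_shape)`. A lambda is not an `nn.Module`, so it is never registered as a submodule, does not move with `.to(device)`, and cannot be pickled with the model; an identity module avoids that. `F_x` is defined elsewhere in this repository, so the following is only a rough sketch of the idea under that assumption, not the project's actual implementation:

import torch
from torch import nn

class IdentitySketch(nn.Module):
    # Stand-in for F_x: remembers the incoming shape and passes tensors through unchanged.
    def __init__(self, in_shape):
        super().__init__()
        self.in_shape = in_shape

    def forward(self, x):
        return x

x = torch.randn(2, 3, 32, 32)
assert torch.equal(IdentitySketch((3, 32, 32))(x), x)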
@@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
         if norm:
-            self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+            norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
+            self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
             self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -181,8 +183,9 @@ class ResidualModule(ShapeMixin, nn.Module):
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

     def forward(self, x):
+        tensor = self.norm(x)
         for module in self.residual_block:
-            tensor = module(x)
+            tensor = module(tensor)

         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
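The corrected forward pass above normalizes the input once, threads the result through every block, and only then adds the skip connection, instead of re-reading `x` on every iteration. A standalone sketch of that pattern with hypothetical layer sizes:

import torch
from torch import nn

class TinyResidualSketch(nn.Module):
    # Minimal pre-norm residual block: norm -> stacked modules -> skip connection.
    def __init__(self, features):
        super().__init__()
        self.norm = nn.BatchNorm1d(features)
        self.residual_block = nn.ModuleList([nn.Linear(features, features), nn.ReLU()])

    def forward(self, x):
        tensor = self.norm(x)
        for module in self.residual_block:
            tensor = module(tensor)
        return tensor + x

out = TinyResidualSketch(8)(torch.randn(4, 8))  # same shape as the input: (4, 8)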
@@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
+
+
+class AttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+        self.dropout = dropout
+        self.features = features
+        raise NotImplementedError
+
+    def forward(self, x):
+        pass
+
+
+class MultiHeadAttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, heads, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+
+        self.features = features
+        self.heads = heads
+        self.final_dim = self.features // self.heads
+
+        self.linear_q = LinearModule(self.features, self.features)
+        self.linear_v = LinearModule(self.features, self.features)
+        self.linear_k = LinearModule(self.features, self.features)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
+        self.linear_out = nn.Linear(self.features, self.features)
+
+    def forward(self, q, k, v, mask=None):
+
+        batch_size = q.size(0)
+
+        # perform linear operation and split into h heads
+        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
+        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
+        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
+
+        # transpose to get dimensions bs * h * sl * features
+        # ToDo: Do we need this?
+
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # calculate attention
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            scores = scores.masked_fill(mask == 0, -1e9)
+        scores = F.softmax(scores, dim=-1)
+        scores = self.dropout(scores)
+        scores = torch.matmul(scores, v)
+
+        # concatenate heads and apply final linear transformation
+        # ToDo: This seems to be old coding style. Do we Need this?
+        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
+
+        output = self.linear_out(concat)
+        return output
+
+
+class TransformerModule(ShapeMixin, nn.Module):

+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+        super(TransformerModule, self).__init__()
+
+        self.in_shape = in_shape
+
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+
+        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
+                                                   dropout=dropout, activation=kwargs.get('activation')
+                                                   )
+        self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+    def forward(self, x, mask=None, key_padding_mask=None):
+        tensor = self.flat(x)
+        tensor = self.transformer(tensor, mask, key_padding_mask)
+        return tensor
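As a sanity check on the attention math added above, the following self-contained snippet reproduces the same split-heads / scaled-dot-product / re-concatenate sequence with plain nn.Linear projections standing in for the repository's LinearModule (an assumption) and small illustrative sizes:

import math
import torch
from torch import nn
from torch.nn import functional as F

batch, seq_len, features, heads = 2, 5, 16, 4
final_dim = features // heads
q = k = v = torch.randn(batch, seq_len, features)

proj = nn.Linear(features, features)  # stand-in for linear_q / linear_k / linear_v
q_ = proj(q).view(batch, -1, heads, final_dim).transpose(1, 2)  # (batch, heads, seq, final_dim)
k_ = proj(k).view(batch, -1, heads, final_dim).transpose(1, 2)
v_ = proj(v).view(batch, -1, heads, final_dim).transpose(1, 2)

scores = torch.matmul(q_, k_.transpose(-2, -1)) / math.sqrt(final_dim)  # (batch, heads, seq, seq)
weights = F.softmax(scores, dim=-1)
out = torch.matmul(weights, v_)                                    # (batch, heads, seq, final_dim)
out = out.transpose(1, 2).contiguous().view(batch, -1, features)   # back to (batch, seq, features)
assert out.shape == (batch, seq_len, features)

The transpose to (batch, heads, seq, final_dim) that the ToDo comment asks about is what lets the matmul batch over the heads, and the final transpose/contiguous/view is the usual way to merge the heads back before the output projection.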
@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader

-from .blocks import ConvModule, DeConvModule, LinearModule
+from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule

 from .util import ShapeMixin, LightningBaseModule, Flatten

@@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
         assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'

         min_max = self._find_min_max(dataloader) if dataloader else [None, None]
-        # assert not any([x is None for x in min_max])
+        # assert not any([tensor is None for tensor in min_max])
         lat_min = torch.as_tensor(lat_min or min_max[0])
         lat_max = lat_max or min_max[1]

@@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):

         # Optional Padding for odd image-sizes
         # Obsolete, can be done by autopadding module on incoming tensors
-        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
+        # in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]

         # Parameters
         self.lat_dim = lat_dim
@@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
+
+
+class TransformerEncoder(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(TransformerEncoder, self).__init__()
+        # MultiheadSelfAttention
+        self.msa = MultiHeadAttentionModule()
+
+
+    def forward(self, x):
+        pass
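TransformerEncoder is still a stub in this commit: the attention module is constructed without its required arguments and forward is an empty placeholder. If it were completed along conventional encoder-block lines, it might look roughly like the sketch below; it uses torch's built-in nn.MultiheadAttention instead of the repository's MultiHeadAttentionModule, and the argument names and feed-forward sizing are assumptions, not code from this project:

import torch
from torch import nn

class EncoderBlockSketch(nn.Module):
    # Hypothetical completion: self-attention and a feed-forward net,
    # each followed by a residual connection and layer norm.
    def __init__(self, features, heads, dropout=0.1):
        super().__init__()
        self.msa = nn.MultiheadAttention(features, heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(features)
        self.norm2 = nn.LayerNorm(features)
        self.ff = nn.Sequential(nn.Linear(features, 4 * features), nn.ReLU(),
                                nn.Linear(4 * features, features))

    def forward(self, x):
        attended, _ = self.msa(x, x, x)
        x = self.norm1(x + attended)
        return self.norm2(x + self.ff(x))

y = EncoderBlockSketch(16, 4)(torch.randn(2, 5, 16))  # (batch, seq, features) in and out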
modules/util.py (116 changed lines)
@@ -8,90 +8,93 @@ from operator import mul
 from torch import nn
 from torch.nn import functional as F

-import pytorch_lightning as pl


 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters

+try:
+    import pytorch_lightning as pl

-class LightningBaseModule(pl.LightningModule, ABC):
+    class LightningBaseModule(pl.LightningModule, ABC):

-    @classmethod
-    def name(cls):
-        return cls.__name__
+        @classmethod
+        def name(cls):
+            return cls.__name__

-    @property
-    def shape(self):
-        try:
-            x = torch.randn(self.in_shape).unsqueeze(0)
-            output = self(x)
-            return output.shape[1:]
-        except Exception as e:
-            print(e)
-            return -1
+        @property
+        def shape(self):
+            try:
+                x = torch.randn(self.in_shape).unsqueeze(0)
+                output = self(x)
+                return output.shape[1:]
+            except Exception as e:
+                print(e)
+                return -1

-    def __init__(self, hparams):
-        super(LightningBaseModule, self).__init__()
+        def __init__(self, hparams):
+            super(LightningBaseModule, self).__init__()

-        # Set Parameters
-        ################################
-        self.hparams = hparams
-        self.params = ModelParameters(hparams)
+            # Set Parameters
+            ################################
+            self.hparams = hparams
+            self.params = ModelParameters(hparams)

-        # Dataset Loading
-        ################################
-        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
-    def size(self):
-        return self.shape
+        def size(self):
+            return self.shape

-    def save_to_disk(self, model_path):
-        Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
-        if not (model_path / 'model_class.obj').exists():
-            with (model_path / 'model_class.obj').open('wb') as f:
-                torch.save(self.__class__, f)
-        return True
+        def save_to_disk(self, model_path):
+            Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
+            if not (model_path / 'model_class.obj').exists():
+                with (model_path / 'model_class.obj').open('wb') as f:
+                    torch.save(self.__class__, f)
+            return True

-    @property
-    def data_len(self):
-        return len(self.dataset.train_dataset)
+        @property
+        def data_len(self):
+            return len(self.dataset.train_dataset)

-    @property
-    def n_train_batches(self):
-        return len(self.train_dataloader())
+        @property
+        def n_train_batches(self):
+            return len(self.train_dataloader())

-    def configure_optimizers(self):
-        raise NotImplementedError
+        def configure_optimizers(self):
+            raise NotImplementedError

-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
+        def forward(self, *args, **kwargs):
+            raise NotImplementedError

-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        raise NotImplementedError
+        def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+            raise NotImplementedError

-    def test_step(self, *args, **kwargs):
-        raise NotImplementedError
+        def test_step(self, *args, **kwargs):
+            raise NotImplementedError

-    def test_epoch_end(self, outputs):
-        raise NotImplementedError
+        def test_epoch_end(self, outputs):
+            raise NotImplementedError

-    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
-        self.apply(weight_initializer)
+        def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
+            weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+            self.apply(weight_initializer)

+    modules = [LightningBaseModule, nn.Module]
+
+except ImportError:
+    modules = [nn.Module, ]
+    pass  # Maybe post a hint to install pytorch-lightning.


 class ShapeMixin:

     @property
     def shape(self):
-        assert isinstance(self, (LightningBaseModule, nn.Module))
+        assert isinstance(self, modules)

         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

-        if self.in_shape is not None:
+        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
+        if in_shape is not None:
             try:
                 device = self.device
             except AttributeError:
@@ -99,10 +102,11 @@ class ShapeMixin:
                     device = next(self.parameters()).device
                 except StopIteration:
                     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-            x = torch.randn(self.in_shape, device=device)
+            x = torch.randn(in_shape, device=device)
+            # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))

             # noinspection PyCallingNonCallable
             y = self(x)
             if isinstance(y, tuple):
                 shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
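The shape property patched above infers a module's output shape by pushing a dummy batch through it on the right device. A condensed sketch of that probing trick, independent of the mixin machinery (function and variable names here are illustrative):

import torch
from torch import nn

def infer_out_shape(module: nn.Module, in_shape):
    # Stack the sample twice so layers such as BatchNorm see a batch
    # dimension larger than one, mirroring torch.stack((x, x)) above.
    x = torch.randn(in_shape)
    x = torch.stack((x, x))
    with torch.no_grad():
        y = module(x)
    return y.shape[1:]

print(infer_out_shape(nn.Conv2d(3, 8, kernel_size=3), (3, 32, 32)))  # torch.Size([8, 30, 30])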
@@ -265,7 +269,7 @@ class Splitter(nn.Module):
         self.autopad = AutoPadToShape(self._out_shape)

     def forward(self, x: torch.Tensor):
-        dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
         x = x.transpose(0, dim)
         n_blocks = list()
         for block_idx in range(self.n):
@@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
         # TODO: Do this programmatically; This did not work:
         # Initialize Default Sections as Property
         # for section in self.default_sections:
-        #     self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+        #     self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))

         @property
         def main(self):