Transformer Implementation

commit 13812b83b5
parent f296ba78b9
@@ -52,7 +52,7 @@ class NormalizeLocal(object):
         std = x.std() + 0.0001

         # Pytorch Version:
-        # x = x.__sub__(mean).__div__(std)
+        # tensor = tensor.__sub__(mean).__div__(std)
         # Numpy Version
         x = (x - mean) / std
         x[np.isnan(x)] = 0
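The change above only renames variables inside a commented-out line; the active NumPy path stays the same. For reference, a minimal standalone sketch of the normalization this transform performs (assuming the mean is taken over the whole array, mirroring the std line shown above):

import numpy as np

def normalize_local(x: np.ndarray) -> np.ndarray:
    # Zero-mean, unit-variance scaling with a small epsilon on the std.
    mean = x.mean()
    std = x.std() + 0.0001
    x = (x - mean) / std
    # Degenerate inputs can still produce NaNs; zero them out.
    x[np.isnan(x)] = 0
    return x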
@@ -1,3 +1,5 @@
+import math
+
 from pathlib import Path
 from typing import Union

@@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):

         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)

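F_x is a helper defined elsewhere in this repository; the change above replaces bare `lambda x: x` placeholders with it. A minimal identity-module sketch in the spirit of that helper (name and constructor argument are inferred from its usage here, so treat them as assumptions):

from torch import nn

class IdentitySketch(nn.Module):
    # Stand-in for the repo's F_x: keeps the expected input shape around
    # for bookkeeping and returns its input unchanged.
    def __init__(self, in_shape):
        super().__init__()
        self.in_shape = in_shape

    def forward(self, x):
        return x

Unlike a lambda, an nn.Module placeholder registers in the module tree and can be pickled or scripted, which is presumably the motivation for the swap.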
@@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
         if norm:
-            self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+            norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
+            self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
             self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -181,8 +183,9 @@ class ResidualModule(ShapeMixin, nn.Module):
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

     def forward(self, x):
+        tensor = self.norm(x)
         for module in self.residual_block:
-            tensor = module(x)
+            tensor = module(tensor)

         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
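The forward fix above makes the residual block actually chain its sub-modules (previously every sub-module re-read the raw input x) and normalizes once before the chain. The resulting pattern, as a standalone sketch assuming shape-preserving sub-modules:

def residual_forward(x, norm, residual_block):
    # Normalize once, run the sub-modules in sequence, then add the skip path.
    tensor = norm(x)
    for module in residual_block:
        tensor = module(tensor)
    return tensor + x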
@@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
+
+
+class AttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+        self.dropout = dropout
+        self.features = features
+        raise NotImplementedError
+
+    def forward(self, x):
+        pass
+
+
+class MultiHeadAttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, heads, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+
+        self.features = features
+        self.heads = heads
+        self.final_dim = self.features // self.heads
+
+        self.linear_q = LinearModule(self.features, self.features)
+        self.linear_v = LinearModule(self.features, self.features)
+        self.linear_k = LinearModule(self.features, self.features)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
+        self.linear_out = nn.Linear(self.features, self.features)
+
+    def forward(self, q, k, v, mask=None):
+
+        batch_size = q.size(0)
+
+        # perform linear operation and split into h heads
+        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
+        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
+        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
+
+        # transpose to get dimensions bs * h * sl * features
+        # ToDo: Do we need this?
+
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # calculate attention
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            scores = scores.masked_fill(mask == 0, -1e9)
+        scores = F.softmax(scores, dim=-1)
+        scores = self.dropout(scores)
+        scores = torch.matmul(scores, v)
+
+        # concatenate heads and apply final linear transformation
+        # ToDo: This seems to be old coding style. Do we Need this?
+        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
+
+        output = self.linear_out(concat)
+        return output
+
+
+class TransformerModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+        super(TransformerModule, self).__init__()
+
+        self.in_shape = in_shape
+
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+
+        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
+                                                   dropout=dropout, activation=kwargs.get('activation')
+                                                   )
+        self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+    def forward(self, x, mask=None, key_padding_mask=None):
+        tensor = self.flat(x)
+        tensor = self.transformer(tensor, mask, key_padding_mask)
+        return tensor
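For reference, the attention arithmetic inside MultiHeadAttentionModule.forward, written against plain nn.Linear layers so it runs standalone. The sizes are illustrative and LinearModule is assumed to behave like nn.Linear here:

import math
import torch
from torch import nn
from torch.nn import functional as F

batch_size, seq_len, features, heads = 2, 5, 8, 2
final_dim = features // heads

q = k = v = torch.randn(batch_size, seq_len, features)
linear_q, linear_k, linear_v = (nn.Linear(features, features) for _ in range(3))

# Project, split into heads, move the head axis forward:
# (bs, sl, features) -> (bs, heads, sl, final_dim)
q_h = linear_q(q).view(batch_size, -1, heads, final_dim).transpose(1, 2)
k_h = linear_k(k).view(batch_size, -1, heads, final_dim).transpose(1, 2)
v_h = linear_v(v).view(batch_size, -1, heads, final_dim).transpose(1, 2)

# Scaled dot-product attention, batched over the head axis.
scores = torch.matmul(q_h, k_h.transpose(-2, -1)) / math.sqrt(final_dim)
weights = F.softmax(scores, dim=-1)
attended = torch.matmul(weights, v_h)          # (bs, heads, sl, final_dim)

# Merge the heads back before the output projection.
concat = attended.transpose(1, 2).contiguous().view(batch_size, -1, features)
print(concat.shape)                            # torch.Size([2, 5, 8])

Regarding the two ToDo questions in the hunk: the transposes are what let torch.matmul batch the score computation over the head axis, and the transpose/contiguous/view at the end is the usual way to merge heads back into the feature dimension, so both steps are still needed.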
@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader

-from .blocks import ConvModule, DeConvModule, LinearModule
+from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule

 from .util import ShapeMixin, LightningBaseModule, Flatten

@@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
         assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'

         min_max = self._find_min_max(dataloader) if dataloader else [None, None]
-        # assert not any([x is None for x in min_max])
+        # assert not any([tensor is None for tensor in min_max])
         lat_min = torch.as_tensor(lat_min or min_max[0])
         lat_max = lat_max or min_max[1]

@@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):

         # Optional Padding for odd image-sizes
         # Obsolet, cdan be done by autopadding module on incoming tensors
-        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
+        # in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]

         # Parameters
         self.lat_dim = lat_dim
@@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
+
+
+class TransformerEncoder(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(TransformerEncoder, self).__init__()
+        # MultiheadSelfAttention
+        self.msa = MultiHeadAttentionModule()
+
+
+    def forward(self, x):
+
+
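TransformerEncoder above is still a stub: MultiHeadAttentionModule is instantiated without its required arguments and forward has no body yet. A hedged sketch of how the self-attention wiring presumably ends up, based only on the MultiHeadAttentionModule signature added earlier in this commit (the argument values are placeholders, not part of the commit):

class TransformerEncoderSketch(ShapeMixin, nn.Module):
    # Sketch only, not part of the commit.
    def __init__(self, in_shape, heads=4, features=128, dropout=0.1):
        super().__init__()
        self.in_shape = in_shape
        # Self-attention: the same tensor serves as query, key and value.
        self.msa = MultiHeadAttentionModule(in_shape, heads, features, dropout=dropout)

    def forward(self, x):
        return self.msa(x, x, x)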
modules/util.py  (116 lines changed)
@@ -8,90 +8,93 @@ from operator import mul
 from torch import nn
 from torch.nn import functional as F

-import pytorch_lightning as pl


 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters

+try:
+    import pytorch_lightning as pl

 class LightningBaseModule(pl.LightningModule, ABC):

     @classmethod
     def name(cls):
         return cls.__name__

     @property
     def shape(self):
         try:
             x = torch.randn(self.in_shape).unsqueeze(0)
             output = self(x)
             return output.shape[1:]
         except Exception as e:
             print(e)
             return -1

     def __init__(self, hparams):
         super(LightningBaseModule, self).__init__()

         # Set Parameters
         ################################
         self.hparams = hparams
         self.params = ModelParameters(hparams)

-        # Dataset Loading
-        ################################
-        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here

     def size(self):
         return self.shape

     def save_to_disk(self, model_path):
         Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
         if not (model_path / 'model_class.obj').exists():
             with (model_path / 'model_class.obj').open('wb') as f:
                 torch.save(self.__class__, f)
         return True

     @property
     def data_len(self):
         return len(self.dataset.train_dataset)

     @property
     def n_train_batches(self):
         return len(self.train_dataloader())

     def configure_optimizers(self):
         raise NotImplementedError

     def forward(self, *args, **kwargs):
         raise NotImplementedError

     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         raise NotImplementedError

     def test_step(self, *args, **kwargs):
         raise NotImplementedError

     def test_epoch_end(self, outputs):
         raise NotImplementedError

     def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
         weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
         self.apply(weight_initializer)

+    modules = [LightningBaseModule, nn.Module]
+
+except ImportError:
+    modules = [nn.Module, ]
+    pass  # Maybe post a hint to install pytorch-lightning.


 class ShapeMixin:

     @property
     def shape(self):
-        assert isinstance(self, (LightningBaseModule, nn.Module))
+        assert isinstance(self, modules)

         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

-        if self.in_shape is not None:
+        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
+        if in_shape is not None:
             try:
                 device = self.device
             except AttributeError:
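The reshuffle above moves the pytorch_lightning import (and the LightningBaseModule that depends on it) into a try/except ImportError block, so the rest of modules/util.py stays importable without Lightning, and `modules` records which base classes ShapeMixin may be mixed into. The same optional-dependency pattern in isolation, as a generic sketch rather than the repo's code:

from torch import nn

try:
    import pytorch_lightning as pl

    class OptionalLightningBase(pl.LightningModule):
        pass

    base_classes = (OptionalLightningBase, nn.Module)
except ImportError:
    # pytorch_lightning is not installed; fall back to plain torch modules.
    base_classes = (nn.Module,)

def assert_supported(obj):
    # isinstance expects a type or a tuple of types, hence the tuples above.
    assert isinstance(obj, base_classes)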
@@ -99,10 +102,11 @@ class ShapeMixin:
                     device = next(self.parameters()).device
                 except StopIteration:
                     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-            x = torch.randn(self.in_shape, device=device)
+            x = torch.randn(in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))

+            # noinspection PyCallingNonCallable
             y = self(x)
             if isinstance(y, tuple):
                 shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
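ShapeMixin.shape infers a module's output shape by pushing a dummy batch through it; stacking the random sample into a batch of two is what keeps BatchNorm layers, which need more than one value per channel in training mode, from failing during the probe. A stripped-down sketch of that idea (illustrative names; the real mixin also handles devices and tuple outputs):

import torch
from torch import nn

def probe_out_shape(module: nn.Module, in_shape):
    # Dummy batch of two samples so BatchNorm gets valid batch statistics.
    x = torch.randn(in_shape)
    x = torch.stack((x, x))
    with torch.no_grad():
        y = module(x)
    return y.shape[1:]

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
print(probe_out_shape(net, (3, 32, 32)))  # torch.Size([8, 30, 30])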
@@ -265,7 +269,7 @@ class Splitter(nn.Module):
         self.autopad = AutoPadToShape(self._out_shape)

     def forward(self, x: torch.Tensor):
-        dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
         x = x.transpose(0, dim)
         n_blocks = list()
         for block_idx in range(self.n):
@@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
         # TODO: Do this programmatically; This did not work:
         # Initialize Default Sections as Property
         # for section in self.default_sections:
-        #     self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+        #     self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))

     @property
     def main(self):