From 13812b83b5a7a37b2a3cb82cc367af5cba01f63e Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Thu, 29 Oct 2020 16:40:43 +0100
Subject: [PATCH] Transformer Implementation

---
 audio_toolset/audio_io.py |   2 +-
 modules/blocks.py         |  94 ++++++++++++++++++++++++++++--
 modules/model_parts.py    |  19 ++++++-
 modules/util.py           | 116 ++++++++++++++++++++------------------
 utils/config.py           |   2 +-
 5 files changed, 167 insertions(+), 66 deletions(-)

diff --git a/audio_toolset/audio_io.py b/audio_toolset/audio_io.py
index e0e02a2..b57038c 100644
--- a/audio_toolset/audio_io.py
+++ b/audio_toolset/audio_io.py
@@ -52,7 +52,7 @@ class NormalizeLocal(object):
         std = x.std() + 0.0001
 
         # Pytorch Version:
-        # x = x.__sub__(mean).__div__(std)
+        # tensor = tensor.__sub__(mean).__div__(std)
         # Numpy Version
         x = (x - mean) / std
         x[np.isnan(x)] = 0
diff --git a/modules/blocks.py b/modules/blocks.py
index cb85886..df17d2d 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -1,3 +1,5 @@
+import math
+
 from pathlib import Path
 from typing import Union
 
@@ -142,8 +144,8 @@ class DeConvModule(ShapeMixin, nn.Module):
         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(self.in_shape)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
 
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)
@@ -168,8 +170,8 @@ class ResidualModule(ShapeMixin, nn.Module):
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
         if norm:
-            self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+            norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
+            self.norm = norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
            self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -181,8 +183,9 @@
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
 
     def forward(self, x):
+        tensor = self.norm(x)
         for module in self.residual_block:
-            tensor = module(x)
+            tensor = module(tensor)
 
         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
@@ -208,3 +211,84 @@ class RecurrentModule(ShapeMixin, nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
+
+
+class AttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+        self.dropout = dropout
+        self.features = features
+        raise NotImplementedError
+
+    def forward(self, x):
+        pass
+
+
+class MultiHeadAttentionModule(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, heads, features, dropout=0.1):
+        super().__init__()
+        self.in_shape = in_shape
+
+        self.features = features
+        self.heads = heads
+        self.final_dim = self.features // self.heads
+
+        self.linear_q = LinearModule(self.features, self.features)
+        self.linear_v = LinearModule(self.features, self.features)
+        self.linear_k = LinearModule(self.features, self.features)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
+        self.linear_out = nn.Linear(self.features, self.features)
+
+    def forward(self, q, k, v, mask=None):
+
+        batch_size = q.size(0)
+
+        # perform the linear projections and split into h heads
+        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
+        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
+        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
+
+        # transpose to get dimensions (batch, heads, seq_len, final_dim),
+        # so that the attention below is computed within every head separately
+
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            scores = scores.masked_fill(mask == 0, -1e9)
+        scores = F.softmax(scores, dim=-1)
+        scores = self.dropout(scores)
+        scores = torch.matmul(scores, v)
+
+        # concatenate the heads and apply the final linear projection
+        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
+
+        output = self.linear_out(concat)
+        return output
+
+
+class TransformerModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+        super(TransformerModule, self).__init__()
+
+        self.in_shape = in_shape
+
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+
+        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
+                                                   dropout=dropout if dropout is not None else 0.0,
+                                                   activation=kwargs.get('activation', 'relu'))
+        # LayerNorm has to run over the model dimension (flat_shape), not the feedforward size.
+        self.norm = nn.LayerNorm(self.flat_shape) if use_norm else None
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=self.norm)
+
+    def forward(self, x, mask=None, key_padding_mask=None):
+        tensor = self.flat(x)
+        tensor = self.transformer(tensor, mask, key_padding_mask)
+        return tensor
diff --git a/modules/model_parts.py b/modules/model_parts.py
index a89dbc4..a837da4 100644
--- a/modules/model_parts.py
+++ b/modules/model_parts.py
@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader
 
-from .blocks import ConvModule, DeConvModule, LinearModule
+from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
 from .util import ShapeMixin, LightningBaseModule, Flatten
 
 
@@ -25,7 +25,7 @@ class AEBaseModule(LightningBaseModule, ABC):
         assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
 
         min_max = self._find_min_max(dataloader) if dataloader else [None, None]
-        # assert not any([x is None for x in min_max])
+        # assert not any([tensor is None for tensor in min_max])
 
         lat_min = torch.as_tensor(lat_min or min_max[0])
         lat_max = lat_max or min_max[1]
@@ -189,7 +189,7 @@ class BaseEncoder(ShapeMixin, nn.Module):
 
         # Optional Padding for odd image-sizes
         # Obsolet, cdan be done by autopadding module on incoming tensors
-        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
+        # in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]
 
         # Parameters
         self.lat_dim = lat_dim
@@ -275,3 +275,16 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
+
+
+class TransformerEncoder(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(TransformerEncoder, self).__init__()
+        self.in_shape = in_shape
+        # Multihead Self-Attention; still a stub, heads/features need to be wired up:
+        # self.msa = MultiHeadAttentionModule(in_shape, heads=..., features=...)
+        raise NotImplementedError
+
+    def forward(self, x):
+        raise NotImplementedError
diff --git a/modules/util.py b/modules/util.py
index 3eb96ba..020114a 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -8,90 +8,93 @@ from operator import mul
 from torch import nn
 from torch.nn import functional as F
 
-import pytorch_lightning as pl
-
-
 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters
 
+try:
+    import pytorch_lightning as pl
 
-class LightningBaseModule(pl.LightningModule, ABC):
+    class LightningBaseModule(pl.LightningModule, ABC):
 
-    @classmethod
-    def name(cls):
-        return cls.__name__
+        @classmethod
+        def name(cls):
+            return cls.__name__
 
-    @property
-    def shape(self):
-        try:
-            x = torch.randn(self.in_shape).unsqueeze(0)
-            output = self(x)
-            return output.shape[1:]
-        except Exception as e:
-            print(e)
-            return -1
+        @property
+        def shape(self):
+            try:
+                x = torch.randn(self.in_shape).unsqueeze(0)
+                output = self(x)
+                return output.shape[1:]
+            except Exception as e:
+                print(e)
+                return -1
 
-    def __init__(self, hparams):
-        super(LightningBaseModule, self).__init__()
+        def __init__(self, hparams):
+            super(LightningBaseModule, self).__init__()
 
-        # Set Parameters
-        ################################
-        self.hparams = hparams
-        self.params = ModelParameters(hparams)
+            # Set Parameters
+            ################################
+            self.hparams = hparams
+            self.params = ModelParameters(hparams)
 
-        # Dataset Loading
-        ################################
-        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
+        def size(self):
+            return self.shape
 
-    def size(self):
-        return self.shape
+        def save_to_disk(self, model_path):
+            Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
+            if not (model_path / 'model_class.obj').exists():
+                with (model_path / 'model_class.obj').open('wb') as f:
+                    torch.save(self.__class__, f)
+            return True
 
-    def save_to_disk(self, model_path):
-        Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
-        if not (model_path / 'model_class.obj').exists():
-            with (model_path / 'model_class.obj').open('wb') as f:
-                torch.save(self.__class__, f)
-        return True
+        @property
+        def data_len(self):
+            return len(self.dataset.train_dataset)
 
-    @property
-    def data_len(self):
-        return len(self.dataset.train_dataset)
+        @property
+        def n_train_batches(self):
+            return len(self.train_dataloader())
 
-    @property
-    def n_train_batches(self):
-        return len(self.train_dataloader())
+        def configure_optimizers(self):
+            raise NotImplementedError
 
-    def configure_optimizers(self):
-        raise NotImplementedError
+        def forward(self, *args, **kwargs):
+            raise NotImplementedError
 
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
+        def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+            raise NotImplementedError
 
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        raise NotImplementedError
+        def test_step(self, *args, **kwargs):
+            raise NotImplementedError
 
-    def test_step(self, *args, **kwargs):
-        raise NotImplementedError
+        def test_epoch_end(self, outputs):
+            raise NotImplementedError
 
-    def test_epoch_end(self, outputs):
-        raise NotImplementedError
+        def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
+            weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+            self.apply(weight_initializer)
 
-    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
-        self.apply(weight_initializer)
+    modules = (LightningBaseModule, nn.Module)
+
+except ImportError:
+    modules = (nn.Module, )
+    pass  # Maybe post a hint to install pytorch-lightning.
 
 
 class ShapeMixin:
 
     @property
     def shape(self):
-        assert isinstance(self, (LightningBaseModule, nn.Module))
+
+        assert isinstance(self, modules)
 
         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
 
-        if self.in_shape is not None:
+        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
+        if in_shape is not None:
             try:
                 device = self.device
             except AttributeError:
@@ -99,10 +102,11 @@ class ShapeMixin:
                     device = next(self.parameters()).device
                 except StopIteration:
                     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-            x = torch.randn(self.in_shape, device=device)
+            x = torch.randn(in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
 
+            # noinspection PyCallingNonCallable
             y = self(x)
             if isinstance(y, tuple):
                 shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
@@ -265,7 +269,7 @@ class Splitter(nn.Module):
         self.autopad = AutoPadToShape(self._out_shape)
 
     def forward(self, x: torch.Tensor):
-        dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
         x = x.transpose(0, dim)
         n_blocks = list()
         for block_idx in range(self.n):
diff --git a/utils/config.py b/utils/config.py
index b8bb63c..040c231 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -102,7 +102,7 @@ class Config(ConfigParser, ABC):
         # TODO: Do this programmatically; This did not work:
         # Initialize Default Sections as Property
         # for section in self.default_sections:
-        #     self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+        #     self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
 
     @property
     def main(self):
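
Usage note (illustration only, not part of the patch): the new MultiHeadAttentionModule in modules/blocks.py implements standard multi-head scaled dot-product attention. The standalone sketch below reproduces that computation with plain torch ops so the tensor shapes are easy to follow; the batch, sequence and feature sizes are assumed example values, not taken from the patch.

import math

import torch
from torch.nn import functional as F

batch, seq_len, features, heads = 2, 5, 16, 4
d_k = features // heads  # per-head feature size, i.e. final_dim in the module above

# self-attention: query, key and value come from the same tensor
q = k = v = torch.randn(batch, seq_len, features)

# split into heads: (batch, seq_len, features) -> (batch, heads, seq_len, d_k)
q, k, v = (t.view(batch, -1, heads, d_k).transpose(1, 2) for t in (q, k, v))

# scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
weights = F.softmax(scores, dim=-1)
attended = torch.matmul(weights, v)                    # (batch, heads, seq_len, d_k)

# concatenate the heads again; a final nn.Linear(features, features) would follow
concat = attended.transpose(1, 2).contiguous().view(batch, -1, features)
print(concat.shape)                                    # torch.Size([2, 5, 16])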