# masks_augments_compare-21/models/transformer_model_sequential.py
from argparse import Namespace
import warnings

import torch
from torch import nn

from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)

from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                BaseDataloadersMixin)

MIN_NUM_PATCHES = 16


class SequentialVisualTransformer(BinaryMaskDatasetMixin,
                                  BaseDataloadersMixin,
                                  BaseTrainMixin,
                                  BaseValMixin,
                                  BaseOptimizerMixin,
                                  LightningBaseModule
                                  ):

    def __init__(self, hparams):
        super(SequentialVisualTransformer, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()
        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size
        self.patch_size = self.params.patch_size
        self.height = height
        self.width = width
        self.channels = channels

        self.new_width = ((self.width - self.patch_size) // 1) + 1
        num_patches = self.new_width - (self.patch_size // 2)
        patch_dim = channels * self.patch_size * self.height
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'
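
        # Worked example of the patch arithmetic (illustrative, assumed shapes only):
        # an input of shape (1, 128, 60) with patch_size 9 and stride 1 gives
        # new_width = (60 - 9) // 1 + 1 = 52, num_patches = 52 - 9 // 2 = 48
        # and patch_dim = 1 * 9 * 128 = 1152.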

        # Correct the Embedding Dim
        if self.embed_dim % self.params.heads != 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding dimension was adjusted to be divisible by the number' +
                       f' of attention heads; it is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)
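
        # Illustrative example (assumed values): embedding_size 100 with 6 heads
        # is rounded down to (100 // 6) * 6 = 96.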

        # Utility Modules
        self.autopad = AutoPadToShape((self.height, self.new_width))
        self.dropout = nn.Dropout(self.params.dropout)
        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.patch_size),
                                    keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
                                             n_heads=self.params.heads, num_layers=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation_as_string
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, mask=None):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :return: Namespace with the sigmoid classifier output under `main_out`.
        """
        # Pad the input, cut it into (height x patch_size) patches and map them to the embedding dimension.
        tensor = self.autopad(x)
        tensor = self.slider(tensor)
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # Prepend the class token and add the positional embedding.
        # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        cls_tokens = self.cls_token.repeat((b, 1, 1))
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        # Run the transformer and classify from the class-token position.
        tensor = self.transformer(tensor, mask)
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor)
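

if __name__ == '__main__':
    # Minimal, self-contained sketch of the patch/embedding flow used above, written
    # with plain torch ops so it runs without the project's AutoPadToShape/SlidingWindow
    # modules or a dataset. All shapes and hyperparameters here are illustrative
    # assumptions, not values taken from the project configuration.
    b, c, h, w = 2, 1, 128, 60                  # two single-channel, spectrogram-like inputs
    patch_size, embed_dim = 9, 96

    x = torch.randn(b, c, h, w)
    # Slide a (h x patch_size) window over the width axis with stride 1.
    patches = x.unfold(dimension=3, size=patch_size, step=1)        # (b, c, h, new_width, patch_size)
    new_width = patches.shape[3]                                    # (w - patch_size) // 1 + 1 = 52
    patches = patches.permute(0, 3, 1, 2, 4).reshape(b, new_width, c * h * patch_size)

    patch_to_embedding = nn.Linear(c * h * patch_size, embed_dim)
    cls_token = torch.randn(1, 1, embed_dim)
    pos_embedding = torch.randn(1, new_width + 1, embed_dim)

    tokens = patch_to_embedding(patches)                            # (b, new_width, embed_dim)
    tokens = torch.cat((cls_token.repeat(b, 1, 1), tokens), dim=1)  # prepend the class token
    tokens = tokens + pos_embedding[:, :tokens.shape[1]]            # add positional embedding
    print(tokens.shape)                                             # torch.Size([2, 53, 96])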