# masks_augments_compare-21/models/transformer_model.py
import variables as V

import torch
from torch import nn

from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)

from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                BaseDataloadersMixin)

MIN_NUM_PATCHES = 16


class VisualTransformer(BinaryMaskDatasetMixin,
                        BaseDataloadersMixin,
                        BaseTrainMixin,
                        BaseValMixin,
                        BaseOptimizerMixin,
                        LightningBaseModule
                        ):
    def __init__(self, hparams):
        super(VisualTransformer, self).__init__(hparams)

        # Dataset (built first, its sample shape drives the model dimensions)
        # =============================================================================
        self.dataset = self.build_dataset()

        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three dimensions'
        channels, height, width = self.in_shape

        # Automatic Image Shaping: round the larger side up to the next multiple of the patch size
        image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
        self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size

        # This should be obsolete
        assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.params.patch_size) ** 2
        patch_dim = channels * self.params.patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
                                              f'attention. Try decreasing your patch size'

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.attention_dim = self.params.features

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim),
                                          requires_grad=False)
        self.patch_to_embedding = nn.Linear(patch_dim, self.attention_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), requires_grad=False)
        self.dropout = nn.Dropout(self.params.dropout)

        self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
                                             self.params.lat_dim, self.params.dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.attention_dim),
            nn.Linear(self.attention_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
        )
    def forward(self, x, mask=None):
        """
        :param x: the image batch to encode (required).
        :param mask: the mask for the src sequence (optional).
        :return: class logits for every sample in the batch
        """
        p = self.params.patch_size

        # Pad to the quadratic self.image_size so the patch grid comes out even.
        tensor = self.autopad(x)
        b, c, _, _ = tensor.shape

        # 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
        tensor = tensor.unfold(2, p, p).unfold(3, p, p)           # b, c, h/p, w/p, p, p
        tensor = tensor.permute(0, 2, 3, 4, 5, 1).contiguous()    # b, h/p, w/p, p, p, c
        tensor = tensor.view(b, -1, p * p * c)                    # b, num_patches, patch_dim
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # '() n d -> b n d', b = b
        cls_tokens = self.cls_token.repeat(b, 1, 1)
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.transformer(tensor, mask)
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return tensor
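

# The quick self-check below is not part of the original module; it is a minimal,
# hedged sketch that verifies the unfold-based patch rearrangement used in
# `forward` on a dummy batch. The batch size, channel count, image size and patch
# size chosen here (2, 1, 64, 8) are illustrative assumptions, not values taken
# from the project's hparams.
if __name__ == '__main__':
    p, c, size = 8, 1, 64                                         # assumed patch size, channels, image size
    dummy = torch.randn(2, c, size, size)                         # dummy image batch
    patches = dummy.unfold(2, p, p).unfold(3, p, p)               # b, c, h/p, w/p, p, p
    patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous()      # b, h/p, w/p, p, p, c
    patches = patches.view(2, -1, p * p * c)                      # b, num_patches, patch_dim
    assert patches.shape == (2, (size // p) ** 2, p * p * c)
    print(f'patch tensor shape: {tuple(patches.shape)}')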