import inspect
from argparse import Namespace

import warnings

import torch
from einops import repeat
from torch import nn

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16

class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
                 return_logits=False):

        # TODO: Move this to the parent class, or make it much easier to access... But how?
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(VerticalVisualTransformer, self).__init__(params)

        self.in_shape = in_shape
        self.n_classes = n_classes

        assert len(self.in_shape) == 3, 'in_shape must have exactly three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size
        self.height = height
        self.channels = channels

        # Number of horizontal window positions for a stride-1 sliding window of width `patch_size`.
        self.new_width = ((width - self.params.patch_size) // 1) + 1

        num_patches = self.new_width - (self.params.patch_size // 2)
        patch_dim = channels * self.params.patch_size * self.height
        assert num_patches >= MIN_NUM_PATCHES, f'Your number of patches ({num_patches}) is too small for ' \
                                               f'attention. Try decreasing your patch size.'

        # Utility Modules
        self.autopad = AutoPadToShape((self.height, self.new_width))
        self.dropout = nn.Dropout(self.params.dropout)
        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.params.patch_size),
                                    keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
                                             head_dim=self.params.head_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        logits = self.params.n_classes if self.params.n_classes > 2 else 1

        # Multi-class output gets a softmax over the class logits; binary output a single sigmoid unit.
        outbound_activation = nn.Softmax(dim=-1) if logits > 1 else nn.Sigmoid()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            self.params.activation(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, logits),
            outbound_activation
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the input sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :param return_attn_weights: whether to return the attention weights (optional).
        :return: a Namespace holding the classification output (`main_out`) and, if requested,
                 the attention weights (`attn_weights`, otherwise None).
        """
        tensor = self.autopad(x)
        tensor = self.slider(tensor)

        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # Prepend the learnable class token and add the positional embedding.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        # Classify from the class-token representation.
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)
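

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only, not part of the original training setup).
    # All hyperparameter values below are assumptions; the real values come from this
    # project's config, and the ml_lib / util dependencies must be importable.
    model = VerticalVisualTransformer(
        in_shape=(1, 128, 64), n_classes=2, weight_init='xavier_normal_', activation=nn.GELU,
        embedding_size=32, heads=4, attn_depth=3, patch_size=9, use_residual=True, variable_length=False,
        use_bias=True, use_norm=True, dropout=0.2, lat_dim=32, loss='binary_cross_entropy',
        scheduler='LambdaLR', mlp_dim=32, head_dim=16, lr=1e-3, weight_decay=0,
        sto_weight_avg=False, lr_scheduler_parameter=0.97, opt_reset_interval=0,
    )
    # Dummy batch with the assumed in_shape of (channels, height, width) = (1, 128, 64).
    out = model(torch.randn(4, 1, 128, 64))
    print(out.main_out.shape)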