Transformer Implementation
This commit is contained in:
parent
b57eabb371
commit
7bac9e984b
@ -38,7 +38,7 @@ main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="")
|
|||||||
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
||||||
|
|
||||||
# Model Parameters
|
# Model Parameters
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
|
||||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||||
|
96
models/transformer_model.py
Normal file
96
models/transformer_model.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import variables as V
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ml_lib.modules.blocks import TransformerModule
|
||||||
|
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
|
||||||
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
MIN_NUM_PATCHES = 16
|
||||||
|
|
||||||
|
|
||||||
|
class VisualTransformer(BinaryMaskDatasetMixin,
|
||||||
|
BaseDataloadersMixin,
|
||||||
|
BaseTrainMixin,
|
||||||
|
BaseValMixin,
|
||||||
|
BaseOptimizerMixin,
|
||||||
|
LightningBaseModule
|
||||||
|
):
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super(VisualTransformer, self).__init__(hparams)
|
||||||
|
|
||||||
|
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||||
|
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||||
|
channels, height, width = self.in_shape
|
||||||
|
|
||||||
|
# Automatic Image Shaping
|
||||||
|
image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
|
||||||
|
self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
|
||||||
|
|
||||||
|
# This should be obsolete
|
||||||
|
assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||||
|
|
||||||
|
num_patches = (self.image_size // self.params.patch_size) ** 2
|
||||||
|
patch_dim = channels * self.params.patch_size ** 2
|
||||||
|
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
|
||||||
|
f'attention. Try decreasing your patch size'
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
# =============================================================================
|
||||||
|
self.dataset = self.build_dataset()
|
||||||
|
|
||||||
|
# Model Paramters
|
||||||
|
# =============================================================================
|
||||||
|
# Additional parameters
|
||||||
|
self.attention_dim = self.params.features
|
||||||
|
|
||||||
|
# Utility Modules
|
||||||
|
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||||
|
|
||||||
|
# Modules with Parameters
|
||||||
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
|
||||||
|
self.embedding = nn.Linear(patch_dim, self.attention_dim)
|
||||||
|
self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
|
||||||
|
self.dropout = nn.Dropout(self.params.dropout)
|
||||||
|
|
||||||
|
self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
|
||||||
|
self.params.lat_dim, self.params.dropout)
|
||||||
|
|
||||||
|
self.to_cls_token = nn.Identity()
|
||||||
|
|
||||||
|
self.mlp_head = nn.Sequential(
|
||||||
|
nn.LayerNorm(self.attention_dim),
|
||||||
|
nn.Linear(self.attention_dim, self.params.lat_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(self.params.dropout),
|
||||||
|
nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
"""
|
||||||
|
:param tensor: the sequence to the encoder (required).
|
||||||
|
:param mask: the mask for the src sequence (optional).
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
p = self.params.patch_size
|
||||||
|
# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
|
||||||
|
tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0]))
|
||||||
|
|
||||||
|
tensor = self.patch_to_embedding(tensor)
|
||||||
|
b, n, _ = tensor.shape
|
||||||
|
|
||||||
|
# '() n d -> b n d', b = b
|
||||||
|
cls_tokens = tensor.repeat(self.cls_token, b)
|
||||||
|
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||||
|
tensor += self.pos_embedding[:, :(n + 1)]
|
||||||
|
tensor = self.dropout(tensor)
|
||||||
|
|
||||||
|
tensor = self.transformer(tensor, mask)
|
||||||
|
|
||||||
|
tensor = self.to_cls_token(tensor[:, 0])
|
||||||
|
tensor = self.mlp_head(tensor)
|
||||||
|
return tensor
|
@ -4,6 +4,7 @@ from models.bandwise_conv_classifier import BandwiseConvClassifier
|
|||||||
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
|
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
|
||||||
from models.ensemble import Ensemble
|
from models.ensemble import Ensemble
|
||||||
from models.residual_conv_classifier import ResidualConvClassifier
|
from models.residual_conv_classifier import ResidualConvClassifier
|
||||||
|
from models.transformer_model import VisualTransformer
|
||||||
|
|
||||||
|
|
||||||
class MConfig(Config):
|
class MConfig(Config):
|
||||||
@ -20,5 +21,6 @@ class MConfig(Config):
|
|||||||
Ensemble=Ensemble,
|
Ensemble=Ensemble,
|
||||||
E=Ensemble,
|
E=Ensemble,
|
||||||
ResidualConvClassifier=ResidualConvClassifier,
|
ResidualConvClassifier=ResidualConvClassifier,
|
||||||
RCC=ResidualConvClassifier
|
RCC=ResidualConvClassifier,
|
||||||
|
ViT=VisualTransformer
|
||||||
)
|
)
|
||||||
|
@ -4,5 +4,8 @@ from argparse import Namespace
|
|||||||
CLEAR = 0
|
CLEAR = 0
|
||||||
MASK = 1
|
MASK = 1
|
||||||
|
|
||||||
|
NUM_CLASSES = 2
|
||||||
|
|
||||||
|
|
||||||
# Dataset Options
|
# Dataset Options
|
||||||
DATA_OPTIONS = Namespace(test='test', devel='devel', train='train')
|
DATA_OPTIONS = Namespace(test='test', devel='devel', train='train')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user