From 7bac9e984b28bc96c547e3492df278369f94956e Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Thu, 29 Oct 2020 16:40:43 +0100
Subject: [PATCH] Transformer Implementation

---
 _paramters.py               |  2 +-
 models/transformer_model.py | 99 +++++++++++++++++++++++++++++++++++++++
 util/config.py              |  4 +-
 variables.py                |  3 ++
 4 files changed, 106 insertions(+), 2 deletions(-)
 create mode 100644 models/transformer_model.py

diff --git a/_paramters.py b/_paramters.py
index 3abe2ea..c119a9d 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -38,7 +38,7 @@
 main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="")  # 1.7
 # Model Parameters
-main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
diff --git a/models/transformer_model.py b/models/transformer_model.py
new file mode 100644
index 0000000..6a36b6a
--- /dev/null
+++ b/models/transformer_model.py
@@ -0,0 +1,99 @@
+import variables as V
+
+import torch
+from torch import nn
+
+from ml_lib.modules.blocks import TransformerModule
+from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+                                BaseDataloadersMixin)
+
+MIN_NUM_PATCHES = 16
+
+
+class VisualTransformer(BinaryMaskDatasetMixin,
+                        BaseDataloadersMixin,
+                        BaseTrainMixin,
+                        BaseValMixin,
+                        BaseOptimizerMixin,
+                        LightningBaseModule
+                        ):
+
+    def __init__(self, hparams):
+        super(VisualTransformer, self).__init__(hparams)
+
+        # Dataset
+        # =============================================================================
+        self.dataset = self.build_dataset()
+
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        assert len(self.in_shape) == 3, 'There need to be three dimensions'
+        channels, height, width = self.in_shape
+
+        # Automatic Image Shaping
+        image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
+        self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
+
+        # This should be obsolete
+        assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
+
+        num_patches = (self.image_size // self.params.patch_size) ** 2
+        patch_dim = channels * self.params.patch_size ** 2
+        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
+                                              f'attention. Try decreasing your patch size'
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.attention_dim = self.params.features
+
+        # Utility Modules
+        self.autopad = AutoPadToShape((self.image_size, self.image_size))
+
+        # Modules with Parameters
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
+        self.embedding = nn.Linear(patch_dim, self.attention_dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
+        self.dropout = nn.Dropout(self.params.dropout)
+
+        self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
+                                             self.params.lat_dim, self.params.dropout)
+
+        self.to_cls_token = nn.Identity()
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(self.attention_dim),
+            nn.Linear(self.attention_dim, self.params.lat_dim),
+            nn.GELU(),
+            nn.Dropout(self.params.dropout),
+            nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
+        )
+
+    def forward(self, x, mask=None):
+        """
+        :param x: the batch of images to classify (required).
+        :param mask: the mask for the src sequence (optional).
+        :return:
+        """
+
+        p = self.params.patch_size
+        tensor = self.autopad(x)  # pad the input up to (image_size, image_size)
+        b, c = tensor.shape[0], tensor.shape[1]
+        # 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
+        tensor = tensor.unfold(2, p, p).unfold(3, p, p)
+        tensor = tensor.permute(0, 2, 3, 4, 5, 1).reshape(b, -1, p * p * c)
+
+        tensor = self.embedding(tensor)
+        b, n, _ = tensor.shape
+
+        # '() n d -> b n d', b = b
+        cls_tokens = self.cls_token.repeat(b, 1, 1)
+        tensor = torch.cat((cls_tokens, tensor), dim=1)
+        tensor += self.pos_embedding[:, :(n + 1)]
+        tensor = self.dropout(tensor)
+
+        tensor = self.transformer(tensor, mask)
+
+        tensor = self.to_cls_token(tensor[:, 0])
+        tensor = self.mlp_head(tensor)
+        return tensor
\ No newline at end of file
diff --git a/util/config.py b/util/config.py
index 14c1ac4..3031d6c 100644
--- a/util/config.py
+++ b/util/config.py
@@ -4,6 +4,7 @@ from models.bandwise_conv_classifier import BandwiseConvClassifier
 from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
 from models.ensemble import Ensemble
 from models.residual_conv_classifier import ResidualConvClassifier
+from models.transformer_model import VisualTransformer
 
 
 class MConfig(Config):
@@ -20,5 +21,6 @@ class MConfig(Config):
         Ensemble=Ensemble, E=Ensemble,
         ResidualConvClassifier=ResidualConvClassifier,
-        RCC=ResidualConvClassifier
+        RCC=ResidualConvClassifier,
+        ViT=VisualTransformer
     )
 
diff --git a/variables.py b/variables.py
index 9d13bc2..9995fb9 100644
--- a/variables.py
+++ b/variables.py
@@ -4,5 +4,8 @@ from argparse import Namespace
 CLEAR = 0
 MASK = 1
 
+NUM_CLASSES = 2
+
+
 # Dataset Options
 DATA_OPTIONS = Namespace(test='test', devel='devel', train='train')
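
For reference, the patch-extraction step in forward() reproduces the einops pattern 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' with plain unfold/permute/reshape calls. A minimal standalone sketch of that step follows, assuming toy dimensions (batch 2, 1 channel, 64x64 input, patch size 8) that are illustrative only and not taken from the patch:

import torch

b, c, size, p = 2, 1, 64, 8                    # assumed toy values
x = torch.randn(b, c, size, size)

# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p2 = p
patches = x.unfold(2, p, p).unfold(3, p, p)    # -> b, c, h', w', p, p
patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(b, -1, p * p * c)

num_patches = (size // p) ** 2
assert patches.shape == (b, num_patches, p * p * c)   # (2, 64, 64)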