variable mask size, better image shapes
@@ -41,27 +41,19 @@ class VisualTransformer(CombinedModelMixins,
 
         # Automatic Image Shaping
         self.patch_size = self.params.patch_size
-        image_size = (max(height, width) // self.patch_size) * self.patch_size
-        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
-
-        # This should be obsolete
-        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
+        new_height = (height // self.patch_size) * self.patch_size
+        new_height = new_height + self.patch_size if new_height < height else new_height
+        new_width = (width // self.patch_size) * self.patch_size
+        new_width = new_width + self.patch_size if new_width < width else new_width
 
-        num_patches = (self.image_size // self.patch_size) ** 2
+        num_patches = (new_height // self.patch_size) * (new_width // self.patch_size)
         patch_dim = channels * self.patch_size ** 2
         assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                                f'attention. Try decreasing your patch size'
 
-        # Correct the Embedding Dim
-        #if not self.embed_dim % self.params.heads == 0:
-        #    self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
-        #    message = ('Embedding Dimension was fixed to be devideable by the number' +
-        #               f' of attention heads, is now: {self.embed_dim}')
-        #    for func in print, warnings.warn:
-        #        func(message)
-
         # Utility Modules
-        self.autopad = AutoPadToShape((self.image_size, self.image_size))
+        self.autopad = AutoPadToShape((new_height, new_width))
 
         # Modules with Parameters
         self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
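For reference, a minimal standalone sketch of the new shaping logic above: each side is rounded up to the next multiple of the patch size, and the patch count becomes per-side instead of squared. The concrete numbers are illustrative only, not taken from the model config.

# Round each side up to the next multiple of the patch size,
# then count patches per side (values are made up for illustration).
height, width, channels, patch_size = 300, 500, 3, 32

new_height = (height // patch_size) * patch_size
new_height = new_height + patch_size if new_height < height else new_height   # 320
new_width = (width // patch_size) * patch_size
new_width = new_width + patch_size if new_width < width else new_width        # 512

num_patches = (new_height // patch_size) * (new_width // patch_size)          # 10 * 16 = 160
patch_dim = channels * patch_size ** 2                                        # 3 * 32**2 = 3072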
@@ -95,14 +87,23 @@ class VisualTransformer(CombinedModelMixins,
         """
         tensor = self.autopad(x)
         p = self.params.patch_size
-        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
 
+        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)
+        b, n, _ = tensor.shape
+
+        # mask
+        lengths = torch.count_nonzero(tensor, dim=-1)
+        mask = (lengths == torch.zeros_like(lengths))
+        # CLS-token awareness
+        # mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
+        # mask = repeat(mask, 'b n -> b n', h=self.params.heads)
+
         tensor = self.patch_to_embedding(tensor)
         b, n, _ = tensor.shape
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
 
         tensor = torch.cat((cls_tokens, tensor), dim=1)
 
         tensor += self.pos_embedding[:, :(n + 1)]
         tensor = self.dropout(tensor)
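A minimal sketch of the mask computation added in this hunk, run on a toy tensor: after rearranging the padded image into flattened patches, any patch that is entirely zero (pure padding) is marked in the mask. torch and einops are used exactly as in the diff; the tensor shapes and values here are made up for illustration.

# Toy example: a 4x6 image padded on the right, patch size 2.
import torch
from einops import rearrange

patch_size = 2
image = torch.zeros(1, 3, 4, 6)      # b c h w, already padded to multiples of 2
image[:, :, :, :4] = 1.0             # "real" content in the left 4 pixel columns

patches = rearrange(image, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)',
                    p1=patch_size, p2=patch_size)

lengths = torch.count_nonzero(patches, dim=-1)    # non-zero entries per patch
mask = (lengths == torch.zeros_like(lengths))     # True where a patch is pure padding

print(patches.shape)   # torch.Size([1, 6, 12])
print(mask)            # tensor([[False, False, False, False,  True,  True]])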