New Model, Many Changes

Si11ium
2020-11-22 16:24:00 +01:00
parent be097a111a
commit a079a196af
3 changed files with 39 additions and 43 deletions

@@ -5,8 +5,6 @@ import warnings
import torch
from torch import nn
from einops import repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
@@ -39,16 +37,13 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
self.embed_dim = self.params.embedding_size
self.patch_size = self.params.patch_size
self.height = height
self.width = width
self.channels = channels
# Automatic Image Shaping
image_size = (max(height, width) // self.patch_size) * self.patch_size
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
self.new_width = ((self.width - self.patch_size)//1) + 1
# This should be obsolete
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (self.image_size // self.patch_size) ** 2
patch_dim = channels * self.patch_size * self.image_size
num_patches = self.new_width - (self.patch_size // 2)
patch_dim = channels * self.patch_size * self.height
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
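
Note on the hunk above: the reshaped patch arithmetic replaces square image patches with column patches that span the full input height and slide along the width. Below is a minimal sketch of the new shape computation using hypothetical example values for channels, height, width and patch_size (none of these values appear in the commit); the "//1" reads like a placeholder for a stride of 1.

    # Hypothetical example values; not taken from the commit.
    channels, height, width = 1, 128, 640
    patch_size = 9

    new_width = ((width - patch_size) // 1) + 1      # 632 sliding positions at stride 1
    num_patches = new_width - (patch_size // 2)      # 628 patches per the new formula
    patch_dim = channels * patch_size * height       # 1152; each patch spans the full height

    print(new_width, num_patches, patch_dim)
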
@@ -61,9 +56,9 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.image_size, self.image_size))
self.autopad = AutoPadToShape((self.height, self.new_width))
self.dropout = nn.Dropout(self.params.dropout)
self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.patch_size), keepdim=False)
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
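
Context for the last hunk: AutoPadToShape and SlidingWindow are ml_lib modules whose exact semantics are not shown in this diff, but the new configuration (pad to (height, new_width), then slide a (height, patch_size) window) suggests stride-1 column patches instead of square ones. A rough, hypothetical equivalent using plain torch.Tensor.unfold, assuming example dimensions:

    import torch

    # Hypothetical example values; the real ones come from self.params and the dataset.
    channels, height, padded_width, patch_size = 1, 128, 632, 9
    x = torch.randn(2, channels, height, padded_width)   # (batch, C, H, W) after padding

    # Slide a (height x patch_size) window across the width with stride 1.
    patches = x.unfold(3, patch_size, 1)        # (B, C, H, positions, patch_size)
    patches = patches.permute(0, 3, 1, 2, 4)    # (B, positions, C, H, patch_size)
    patches = patches.flatten(start_dim=2)      # (B, positions, C*H*patch_size)

    # positions == padded_width - patch_size + 1 here, which is close to but not
    # identical to the num_patches formula in the earlier hunk; treat this purely
    # as a shape sketch, not as the project's SlidingWindow implementation.
    print(patches.shape)
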