CCS intergration dataloader

This commit is contained in:
Steffen
2021-03-19 17:17:16 +01:00
parent 6ace861016
commit d4059779c4
8 changed files with 213 additions and 35 deletions

View File

@@ -10,7 +10,7 @@ from einops import rearrange, repeat
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
from util.module_mixins import CombinedModelMixins
MIN_NUM_PATCHES = 16
@@ -25,7 +25,7 @@ class VisualTransformer(CombinedModelMixins,
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
# TODO: Move this to parent class, or make it much easieer to access... But How...
# TODO: Move this to parent class, or make it much easier to access... But How...
a = dict(locals())
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
super(VisualTransformer, self).__init__(params)
@@ -75,7 +75,7 @@ class VisualTransformer(CombinedModelMixins,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, n_classes),
nn.Linear(self.params.lat_dim, self.params.n_classes),
nn.Softmax()
)
@@ -88,7 +88,7 @@ class VisualTransformer(CombinedModelMixins,
tensor = self.autopad(x)
p = self.params.patch_size
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
b, n, _ = tensor.shape
# mask
@@ -96,7 +96,7 @@ class VisualTransformer(CombinedModelMixins,
mask = (lengths == torch.zeros_like(lengths))
# CLS-token awareness
# mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
# mask = repeat(mask, 'b n -> b n', h=self.params.heads)
# mask = repeat(mask, 'b n -> b h n', h=self.params.heads)
tensor = self.patch_to_embedding(tensor)