CCS integration dataloader
@@ -10,7 +10,7 @@ from einops import rearrange, repeat
 from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import TransformerModule
-from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
+from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
 from util.module_mixins import CombinedModelMixins
 
 MIN_NUM_PATCHES = 16
@@ -25,7 +25,7 @@ class VisualTransformer(CombinedModelMixins,
                  use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
                  lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
 
-        # TODO: Move this to parent class, or make it much easieer to access... But How...
+        # TODO: Move this to parent class, or make it much easier to access... But How...
         a = dict(locals())
         params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
         super(VisualTransformer, self).__init__(params)
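The locals()/inspect idiom kept as context above gathers every constructor argument into a single dict before it is handed to the parent class. A minimal, standalone sketch of the same pattern, with a made-up class and made-up argument names for illustration only:

import inspect

class TinyModel:
    def __init__(self, lr, dropout, n_classes):
        a = dict(locals())  # snapshot of all constructor arguments, including self
        # keep only the named hyperparameters
        self.params = {arg: a[arg]
                       for arg in inspect.signature(self.__init__).parameters
                       if arg != 'self'}

m = TinyModel(lr=1e-3, dropout=0.1, n_classes=5)
print(m.params)  # {'lr': 0.001, 'dropout': 0.1, 'n_classes': 5}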
@@ -75,7 +75,7 @@ class VisualTransformer(CombinedModelMixins,
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
-           nn.Linear(self.params.lat_dim, n_classes),
+           nn.Linear(self.params.lat_dim, self.params.n_classes),
            nn.Softmax()
        )
 
@@ -88,7 +88,7 @@ class VisualTransformer(CombinedModelMixins,
         tensor = self.autopad(x)
         p = self.params.patch_size
 
-        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)
+        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
         b, n, _ = tensor.shape
 
         # mask
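The einops change above switches the patch flattening from column-major `(w h)` to row-major `(h w)` order: the patch contents are identical, only their position along the sequence axis moves, which in a ViT-style model typically decides which positional embedding each patch pairs with. A small self-contained sketch with made-up tensor sizes, not taken from the repo:

import torch
from einops import rearrange

x = torch.arange(2 * 1 * 4 * 4, dtype=torch.float32).reshape(2, 1, 4, 4)  # (b, c, h, w)
p = 2  # illustrative patch size -> a 2x2 grid of patches

old = rearrange(x, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)  # column-major patch order
new = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)  # row-major patch order

print(old.shape, new.shape)               # both torch.Size([2, 4, 4])
print(torch.equal(old[:, 1], new[:, 2]))  # True: the patch at grid position (h=1, w=0) moves from index 1 to index 2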
@@ -96,7 +96,7 @@ class VisualTransformer(CombinedModelMixins,
         mask = (lengths == torch.zeros_like(lengths))
         # CLS-token awareness
         # mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
-        # mask = repeat(mask, 'b n -> b n', h=self.params.heads)
+        # mask = repeat(mask, 'b n -> b h n', h=self.params.heads)
 
         tensor = self.patch_to_embedding(tensor)
 
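The corrected (still commented-out) mask line uses `'b n -> b h n'`, which would tile the per-patch padding mask once per attention head rather than leaving its shape unchanged. A minimal sketch of that repeat with made-up sizes; the line itself stays disabled in this commit:

import torch
from einops import repeat

heads = 4                                          # illustrative head count
lengths = torch.tensor([[3., 3., 3., 0., 0.],
                        [5., 5., 5., 5., 5.]])     # zeros mark padded patches
mask = (lengths == torch.zeros_like(lengths))      # (b, n), True where a patch is padding

head_mask = repeat(mask, 'b n -> b h n', h=heads)  # add a head axis
print(head_mask.shape)                             # torch.Size([2, 4, 5])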