adjustment fot CCS

This commit is contained in:
Steffen Illium
2021-03-19 17:16:38 +01:00
parent 74a2603c79
commit 0a0a1cdcb5
5 changed files with 27 additions and 22 deletions

View File

@@ -14,10 +14,12 @@ class CNNBaseline(CombinedModelMixins,
LightningBaseModule
):
def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
def __init__(self, in_shape, n_classes, weight_init, activation,
use_bias, use_norm, dropout, lat_dim, features,
filters,
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
loss, scheduler):
loss, scheduler, lr_scheduler_parameter
):
# TODO: Move this to parent class, or make it much easieer to access....
a = dict(locals())

View File

@@ -21,7 +21,7 @@ class VisualTransformer(CombinedModelMixins,
):
def __init__(self, in_shape, n_classes, weight_init, activation,
embedding_size, heads, attn_depth, patch_size, use_residual,
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
@@ -88,15 +88,18 @@ class VisualTransformer(CombinedModelMixins,
tensor = self.autopad(x)
p = self.params.patch_size
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
b, n, _ = tensor.shape
# mask
lengths = torch.count_nonzero(tensor, dim=-1)
mask = (lengths == torch.zeros_like(lengths))
# CLS-token awareness
# mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
# mask = repeat(mask, 'b n -> b n', h=self.params.heads)
if self.params.variable_length and mask is None:
# mask
lengths = torch.count_nonzero(tensor, dim=-1)
mask = (lengths == torch.zeros_like(lengths))
# CLS-token awareness
# mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
# mask = repeat(mask, 'b n -> b n', h=self.params.heads)
else:
mask = mask
tensor = self.patch_to_embedding(tensor)