paper preperations and notebooks, optuna callbacks
This commit is contained in:
@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
|
||||
class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||
embedding_size, heads, attn_depth, patch_size, use_residual,
|
||||
use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
|
||||
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
|
||||
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
|
||||
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
|
||||
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
|
||||
return_logits=False):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easieer to access... But How...
|
||||
a = dict(locals())
|
||||
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.height, self.new_width))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||
keepdim=False)
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
|
||||
head_dim=self.params.head_dim,
|
||||
heads=self.params.heads, depth=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation
|
||||
activation=self.params.activation, use_residual=self.params.use_residual
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
self.params.activation(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, self.n_classes),
|
||||
nn.Softmax()
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
outbound_activation()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
||||
Reference in New Issue
Block a user