paper preperations and notebooks, optuna callbacks
This commit is contained in:
@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
if return_logits:
|
||||
outbound_activation = nn.Identity()
|
||||
outbound_activation = nn.Identity
|
||||
else:
|
||||
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
|
||||
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
|
||||
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
|
||||
self.params.activation(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
outbound_activation
|
||||
outbound_activation()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
if self.params.n_classes <= 2:
|
||||
return BinaryScores(self)(outputs)
|
||||
else:
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
||||
Reference in New Issue
Block a user