paper preperations and notebooks, optuna callbacks

This commit is contained in:
Steffen Illium
2021-04-02 08:45:11 +02:00
parent 7c88602776
commit cec3a07d60
21 changed files with 3818 additions and 1059 deletions

View File

@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
logits = self.params.n_classes if self.params.n_classes > 2 else 1
if return_logits:
outbound_activation = nn.Identity()
outbound_activation = nn.Identity
else:
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
self.mlp_head = nn.Sequential(
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
self.params.activation(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, logits),
outbound_activation
outbound_activation()
)
def forward(self, x, mask=None, return_attn_weights=False):
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
if self.params.n_classes <= 2:
return BinaryScores(self)(outputs)
else:
return MultiClassScores(self)(outputs)