bugs fixed, binary datasets working

Steffen
2021-03-27 18:23:51 +01:00
parent 37e36df0a8
commit 76eb567eed
4 changed files with 33 additions and 44 deletions


@@ -8,6 +8,7 @@ from torch import nn
from einops import rearrange, repeat
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
@@ -64,11 +65,6 @@ class Tester(CombinedModelMixins,
        self.autopad = AutoPadToShape((self.image_size, self.image_size))
        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
@@ -78,13 +74,17 @@ class Tester(CombinedModelMixins,
        self.to_cls_token = nn.Identity()
        logits = self.params.n_classes if self.params.n_classes > 2 else 1
        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax()
            nn.Linear(self.params.lat_dim, logits),
            outbound_activation()
        )

    def forward(self, x, mask=None, return_attn_weights=False):
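
The head change above is where the commit's "binary datasets" support lives: the number of output logits and the outbound activation are derived from the n_classes hyperparameter instead of being hard-coded to Softmax. A minimal sketch of that pattern, assuming nothing beyond plain PyTorch (the build_head helper and its argument names are illustrative, not part of the repository):

from torch import nn

def build_head(embed_dim: int, lat_dim: int, n_classes: int, dropout: float) -> nn.Sequential:
    # Binary tasks collapse to a single sigmoid logit; multi-class tasks get one
    # softmax output per class, mirroring the nn.Sequential built in the diff.
    logits = n_classes if n_classes > 2 else 1
    # The diff instantiates the activation class without arguments; dim=-1 is
    # passed here only to silence PyTorch's implicit-dim warning for Softmax.
    activation = nn.Sigmoid() if logits == 1 else nn.Softmax(dim=-1)
    return nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, lat_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(lat_dim, logits),
        activation,
    )
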
@@ -106,15 +106,12 @@ class Tester(CombinedModelMixins,
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)
        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
        tensor = self.transformer(tensor, mask)
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)
        return Namespace(main_out=tensor, attn_weights=None)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
        if self.params.n_classes > 2:
            return MultiClassScores(self)(outputs)
        else:
            return BinaryScores(self)(outputs)
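
The additional_scores switch at the end keys off the same n_classes threshold, so the chosen metric class always matches the head's output shape. A quick usage check of the hypothetical build_head sketch above:

import torch

head_bin = build_head(embed_dim=64, lat_dim=32, n_classes=2, dropout=0.1)
head_multi = build_head(embed_dim=64, lat_dim=32, n_classes=5, dropout=0.1)

x = torch.randn(8, 64)            # a batch of 8 CLS-token embeddings
print(head_bin(x).shape)          # torch.Size([8, 1])  -> one sigmoid probability per sample
print(head_multi(x).shape)        # torch.Size([8, 5])  -> softmax distribution over classes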