bugs fixed, binary datasets working
This commit is contained in:
@ -8,6 +8,7 @@ from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
@ -64,11 +65,6 @@ class Tester(CombinedModelMixins,
|
||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
|
||||
heads=self.params.heads, depth=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation, use_residual=self.params.use_residual
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
@ -78,13 +74,17 @@ class Tester(CombinedModelMixins,
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, n_classes),
|
||||
nn.Softmax()
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
outbound_activation()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
@ -106,15 +106,12 @@ class Tester(CombinedModelMixins,
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
if return_attn_weights:
|
||||
tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
|
||||
else:
|
||||
attn_weights = None
|
||||
tensor = self.transformer(tensor, mask)
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
return Namespace(main_out=tensor, attn_weights=None)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
Reference in New Issue
Block a user