adjustment fot CCS, notebook folder
This commit is contained in:
@@ -3,6 +3,7 @@ from argparse import Namespace
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import LinearModule
|
||||
from ml_lib.modules.model_parts import CNNEncoder
|
||||
@@ -36,9 +37,11 @@ class CNNBaseline(CombinedModelMixins,
|
||||
# Modules with Parameters
|
||||
self.encoder = CNNEncoder(in_shape=self.in_shape, **self.params.module_kwargs)
|
||||
|
||||
# Make Decision between binary and Multiclass Classification
|
||||
logits = n_classes if n_classes > 2 else 1
|
||||
module_kwargs = self.params.module_kwargs
|
||||
module_kwargs.update(activation=nn.Softmax)
|
||||
self.classifier = LinearModule(self.encoder.shape, n_classes, **module_kwargs)
|
||||
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
|
||||
self.classifier = LinearModule(self.encoder.shape, logits, **module_kwargs)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
"""
|
||||
@@ -52,4 +55,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -70,13 +68,15 @@ class VisualTransformer(CombinedModelMixins,
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, n_classes),
|
||||
nn.Softmax()
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
nn.Softmax() if logits > 1 else nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
|
||||
Reference in New Issue
Block a user