CCS intergration training running
notebooks
This commit is contained in:
@@ -7,7 +7,7 @@ from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.blocks import (TransformerModule, F_x)
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
|
||||
from util.module_mixins import CombinedModelMixins
|
||||
|
||||
@@ -21,7 +21,8 @@ class VisualTransformer(CombinedModelMixins,
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
|
||||
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
|
||||
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
|
||||
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
|
||||
return_logits=False):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easier to access... But How...
|
||||
a = dict(locals())
|
||||
@@ -69,14 +70,20 @@ class VisualTransformer(CombinedModelMixins,
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
|
||||
if return_logits:
|
||||
outbound_activation = nn.Identity()
|
||||
else:
|
||||
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
|
||||
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
self.params.activation(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
nn.Softmax() if logits > 1 else nn.Sigmoid()
|
||||
outbound_activation
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
|
||||
Reference in New Issue
Block a user