adjustment fot CCS, notebook folder

This commit is contained in:
Steffen Illium
2021-03-22 16:43:19 +01:00
parent 78b3139d1a
commit c12f3866c8
6 changed files with 156 additions and 29 deletions

View File

@@ -1,8 +1,6 @@
import inspect
from argparse import Namespace
import warnings
import torch
from torch import nn
@@ -70,13 +68,15 @@ class VisualTransformer(CombinedModelMixins,
self.to_cls_token = nn.Identity()
logits = self.params.n_classes if self.params.n_classes > 2 else 1
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, n_classes),
nn.Softmax()
nn.Linear(self.params.lat_dim, logits),
nn.Softmax() if logits > 1 else nn.Sigmoid()
)
def forward(self, x, mask=None, return_attn_weights=False):