torchaudio testing
This commit is contained in:
@@ -10,11 +10,12 @@ from einops import rearrange, repeat
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class VisualTransformer(DatasetMixin,
|
||||
BaseLossMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
@@ -84,8 +85,8 @@ class VisualTransformer(DatasetMixin,
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
nn.Linear(self.params.lat_dim, 10),
|
||||
nn.Softmax()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
|
||||
@@ -8,11 +8,12 @@ from torch import nn
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class HorizontalVisualTransformer(DatasetMixin,
|
||||
BaseLossMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
@@ -35,6 +36,7 @@ class HorizontalVisualTransformer(DatasetMixin,
|
||||
# Model Parameters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = self.dataset.train_dataset.n_classes
|
||||
self.embed_dim = self.params.embedding_size
|
||||
self.patch_size = self.params.patch_size
|
||||
self.height = height
|
||||
@@ -81,8 +83,8 @@ class HorizontalVisualTransformer(DatasetMixin,
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
nn.Linear(self.params.lat_dim, 10),
|
||||
nn.Softmax()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
|
||||
@@ -8,11 +8,12 @@ from torch import nn
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class VerticalVisualTransformer(DatasetMixin,
|
||||
BaseLossMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
@@ -80,8 +81,8 @@ class VerticalVisualTransformer(DatasetMixin,
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
nn.Linear(self.params.lat_dim, 10),
|
||||
nn.Softmax()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
|
||||
Reference in New Issue
Block a user