torchaudio testing

This commit is contained in:
Si11ium
2020-12-17 08:02:29 +01:00
parent 95dcf22f3d
commit 68431b848e
13 changed files with 578 additions and 418 deletions

View File

@@ -10,11 +10,12 @@ from einops import rearrange, repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class VisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -84,8 +85,8 @@ class VisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):

View File

@@ -8,11 +8,12 @@ from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class HorizontalVisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -35,6 +36,7 @@ class HorizontalVisualTransformer(DatasetMixin,
# Model Parameters
# =============================================================================
# Additional parameters
self.n_classes = self.dataset.train_dataset.n_classes
self.embed_dim = self.params.embedding_size
self.patch_size = self.params.patch_size
self.height = height
@@ -81,8 +83,8 @@ class HorizontalVisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):

View File

@@ -8,11 +8,12 @@ from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class VerticalVisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -80,8 +81,8 @@ class VerticalVisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):