Audio Dataset
This commit is contained in:
@@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
||||
class BandwiseConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
||||
@@ -6,11 +6,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
|
||||
class BandwiseConvMultiheadClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
||||
@@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class ConvClassifier(BinaryMaskDatasetMixin,
|
||||
class ConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
||||
@@ -8,11 +8,11 @@ from torch.nn import ModuleList
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class Ensemble(BinaryMaskDatasetMixin,
|
||||
class Ensemble(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
||||
@@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class ResidualConvClassifier(BinaryMaskDatasetMixin,
|
||||
class ResidualConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
||||
@@ -9,15 +9,16 @@ from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class VisualTransformer(BinaryMaskDatasetMixin,
|
||||
class VisualTransformer(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseTestMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
111
models/transformer_model_horizontal.py
Normal file
111
models/transformer_model_horizontal.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class HorizontalVisualTransformer(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseTestMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(HorizontalVisualTransformer, self).__init__(hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset()
|
||||
|
||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.embed_dim = self.params.embedding_size
|
||||
self.patch_size = self.params.patch_size
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channels = channels
|
||||
|
||||
self.new_height = ((self.height - self.patch_size)//1) + 1
|
||||
|
||||
num_patches = self.new_height - (self.patch_size // 2)
|
||||
patch_dim = channels * self.patch_size * self.width
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.new_height, self.width))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.patch_size, self.width),
|
||||
keepdim=False)
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||
n_heads=self.params.heads, num_layers=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation_as_string
|
||||
)
|
||||
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
else F_x(self.embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
:param x: the sequence to the encoder (required).
|
||||
:param mask: the mask for the src sequence (optional).
|
||||
:return:
|
||||
"""
|
||||
tensor = self.autopad(x)
|
||||
tensor = self.slider(tensor)
|
||||
|
||||
tensor = self.patch_to_embedding(tensor)
|
||||
b, n, _ = tensor.shape
|
||||
|
||||
# cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = self.cls_token.repeat((b, 1, 1))
|
||||
|
||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
tensor = self.transformer(tensor, mask)
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor)
|
||||
@@ -7,21 +7,22 @@ from torch import nn
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
class VerticalVisualTransformer(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseTestMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(SequentialVisualTransformer, self).__init__(hparams)
|
||||
super(VerticalVisualTransformer, self).__init__(hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
Reference in New Issue
Block a user