New Model, Many Changes

Si11ium 2020-11-21 09:28:26 +01:00
parent 7bac9e984b
commit be097a111a
12 changed files with 349 additions and 125 deletions

View File

@@ -29,35 +29,44 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
 main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
 # Transformation Parameters
-main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")  # 0.4
+main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="")  # 0.4
-main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="")  # 0.4
+main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="")  # 0.4
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")  # 0.4
-main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")  # 0.2
+main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="")  # 0.2
 main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="")  # 0.4
 main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="")  # 0.7
 main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="")  # 1.7
 # Model Parameters
-main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
+# General
+main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
+main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
-main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
+main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
+# CNN Specific
+main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
+# Transformer Specific
+main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
+main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
+main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
+main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
 # Training Parameters
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
+# FIXME: Stochastic weight averaging is not good; maybe it's my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
-main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
+main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
+main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
 # Project Parameters
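
For orientation: these flags presumably live in _paramters.py, from which main.py imports main_arg_parser (see below). A minimal standalone sketch of how the new transformer defaults resolve:

import argparse

# Mirrors only the four transformer flags added above; not the full parser.
parser = argparse.ArgumentParser()
parser.add_argument("--model_patch_size", type=int, default=9)
parser.add_argument("--model_attn_depth", type=int, default=3)
parser.add_argument("--model_heads", type=int, default=8)
parser.add_argument("--model_embedding_size", type=int, default=64)

args = parser.parse_args([])  # empty argv, so the defaults apply
# 64 is divisible by 8 heads, so the models' embedding-dim correction is a no-op here.
assert args.model_embedding_size % args.model_heads == 0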

datasets/urban_8k.py (new file, 95 lines)
View File

@@ -0,0 +1,95 @@
import pickle
from collections import defaultdict
from pathlib import Path

import librosa
import torch
from torch.utils.data import Dataset

import variables as V
from ml_lib.modules.util import F_x


class BinaryMasksDataset(Dataset):
    _to_label = defaultdict(lambda: -1)
    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

    @property
    def sample_shape(self):
        return self[0][0].shape

    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                 use_preprocessed=True):
        self.use_preprocessed = use_preprocessed
        self.stretch = stretch_dataset
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        super(BinaryMasksDataset, self).__init__()

        self.data_root = Path(data_root)
        self.setting = setting
        self._wav_folder = self.data_root / 'wav'
        self._mel_folder = self.data_root / 'mel'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        self._transforms = transforms or F_x(in_shape=None)

    def _build_labels(self):
        labeldict = dict()
        with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
            # Exclude the header
            _ = next(f)
            for row in f:
                if self.setting not in row:
                    continue
                filename, label = row.strip().split(',')
                labeldict[filename] = self._to_label[label.lower()] if self.setting != 'test' else filename
        if self.stretch and self.setting == V.DATA_OPTIONS.train:
            additional_dict = {f'X{key}': val for key, val in labeldict.items()}
            additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
            additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
            labeldict.update(additional_dict)
        # Delete the cached mel file if one exists.
        if not self.use_preprocessed:
            for key in labeldict.keys():
                try:
                    (self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
                except FileNotFoundError:
                    pass
        return labeldict

    def __len__(self):
        return len(self._labels)

    def _compute_or_retrieve(self, filename):
        if not (self._mel_folder / (filename + self.container_ext)).exists():
            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
            mel_sample = self._mel_transform(raw_sample)
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)

        with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
            mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample

    def __getitem__(self, item):
        key: str = list(self._labels.keys())[item]
        filename = key.replace('.wav', '')
        mel_sample = self._compute_or_retrieve(filename)
        label = self._labels[key]

        transformed_samples = self._transforms(mel_sample)
        if self.setting != V.DATA_OPTIONS.test:
            # In 'test', filenames instead of labels are returned. This is a little hacky though.
            label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label
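
A hypothetical usage sketch for the new dataset class. The data_root layout (data_root/wav, data_root/mel, data_root/lab/labels.csv) is implied by the code above; the identity mel transform is a placeholder for the project's real mel pipeline:

from datasets.urban_8k import BinaryMasksDataset

# 'data' and the identity transform are placeholders, not project defaults.
dataset = BinaryMasksDataset(data_root='data', setting='train',
                             mel_transforms=lambda raw_sample: raw_sample)
print(len(dataset))          # number of labelled train files (4x with stretch_dataset=True)
sample, label = dataset[0]   # first access computes and pickles the mel sample, then unpickles it
print(dataset.sample_shape)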

main.py (36 changed lines)
View File

@@ -6,14 +6,13 @@ import warnings
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

 from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.config import Config
 from ml_lib.utils.logging import Logger

 # Project Specific Logger SubClasses
-from util.config import MConfig

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -37,35 +36,30 @@ def run_lightning_loop(config_obj):
     # Callbacks
     # =============================================================================
     # Checkpoint Saving
-    checkpoint_callback = ModelCheckpoint(
-        monitor='uar_score',
+    ckpt_callback = ModelCheckpoint(
+        monitor='mean_loss',
         filepath=str(logger.log_dir / 'ckpt_weights'),
         verbose=False,
         save_top_k=5,
     )
-    # Early Stopping
-    # TODO: For This to work, set a validation step and End Eval and Score
-    early_stopping_callback = EarlyStopping(
-        monitor='uar_score',
-        min_delta=0.01,
-        patience=10,
-    )
+    # Learning Rate Logger
+    lr_logger = LearningRateMonitor(logging_interval='epoch')

     # Trainer
     # =============================================================================
     trainer = Trainer(max_epochs=config_obj.train.epochs,
-                      show_progress_bar=True,
                       weights_save_path=logger.log_dir,
                       gpus=[0] if torch.cuda.is_available() else None,
                       check_val_every_n_epoch=10,
                       # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                       # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                       # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
+                      checkpoint_callback=True,
+                      callbacks=[lr_logger, ckpt_callback],
                       logger=logger,
                       fast_dev_run=config_obj.main.debug,
-                      early_stop_callback=None
+                      auto_lr_find=not config_obj.main.debug
                       )

@@ -78,10 +72,15 @@ def run_lightning_loop(config_obj):
     # Train It
     if config_obj.model.type.lower() != 'ensemble':
+        if not config_obj.main.debug and not config_obj.train.lr:
+            trainer.tune(model)
+            # ToDo: LR Finder Plot
+            # fig = lr_finder.plot(suggest=True)
         trainer.fit(model)

     # Save the last state & all parameters
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+    trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
     model.save_to_disk(logger.log_dir)

     # Evaluate It
@@ -99,8 +98,7 @@ def run_lightning_loop(config_obj):
             outputs.append(
                 model.validation_step((batch_x, label), idx, 1)
             )
-        summary_dict = model.validation_epoch_end([outputs])
-        print(summary_dict['log']['uar_score'])
+        model.validation_epoch_end([outputs])
     # trainer.test()

     outpath = Path(config_obj.train.outpath)
@@ -132,6 +130,6 @@ if __name__ == "__main__":
     from _paramters import main_arg_parser

-    config = MConfig.read_argparser(main_arg_parser)
+    config = Config.read_argparser(main_arg_parser)
     fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
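
The tune-before-fit flow added above, in isolation (PyTorch Lightning of this vintage; `model` is a placeholder for any LightningModule):

from pytorch_lightning import Trainer

model = ...  # placeholder: any LightningModule

trainer = Trainer(max_epochs=100, auto_lr_find=True)
# With auto_lr_find set, tune() runs the LR range test and writes the
# suggested learning rate back onto the model before training starts.
trainer.tune(model)
trainer.fit(model)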

View File

@@ -15,11 +15,10 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
 # Transforms
 from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
-from ml_lib.utils.logging import Logger
+from ml_lib.utils.config import Config
 from ml_lib.utils.model_io import SavedLightningModels
 from ml_lib.utils.transforms import ToTensor
-from ml_lib.visualization.tools import Plotter
-from util.config import MConfig

 # Datasets
 from datasets.binar_masks import BinaryMasksDataset
@@ -66,8 +65,8 @@ if __name__ == '__main__':
     config_filename = 'config.ini'
     inference_out = 'manual_test_out.csv'

-    config = MConfig()
-    config.read_file((Path(model_path) / config_filename).open('r'))
+    config = Config()
+    config.read_file((Path(model_path) / config_filename).open())
     test_dataloader = prepare_dataloader(config)
     loaded_model = restore_logger_and_model(model_path)

View File

@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule
-from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
+from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)
@@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
         # Modules
         # =============================================================================
-        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
+        self.split = Splitter(self.in_shape, self.n_band_sections)

         k = 3
         self.band_list = ModuleList()
@@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
             # last_shape = self.conv_list[-1].shape
             self.band_list.append(conv_list)

-        self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
+        self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
         self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
         self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)

View File

@@ -5,7 +5,7 @@ from torch import nn
 from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule
-from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
+from ml_lib.modules.util import (LightningBaseModule, Splitter)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)
@@ -69,7 +69,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
         # Modules
         # =============================================================================
-        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
+        self.split = Splitter(self.in_shape, self.n_band_sections)

         self.band_list = ModuleList()
         for band in range(self.n_band_sections):
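
The Splitter/Merger rename is mechanical; for orientation, the pair presumably cuts the mel image into n horizontal frequency bands and later re-joins the per-band features. A torch-only illustration under that assumption, with made-up shapes:

import torch

x = torch.randn(4, 1, 128, 200)           # (batch, channels, n_mels, time), made-up shape
bands = torch.chunk(x, chunks=4, dim=2)   # four horizontal frequency bands of 32 mel bins each
features = [band.flatten(start_dim=1) for band in bands]
merged = torch.cat(features, dim=1)       # re-joined per-band features
print(merged.shape)                       # torch.Size([4, 25600])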

View File

@@ -1,16 +1,19 @@
-import variables as V
+from argparse import Namespace
+import warnings

 import torch
 from torch import nn
+from einops import rearrange, repeat

 from ml_lib.modules.blocks import TransformerModule
-from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
+from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)

 MIN_NUM_PATCHES = 16

 class VisualTransformer(BinaryMaskDatasetMixin,
                         BaseDataloadersMixin,
                         BaseTrainMixin,
@@ -22,69 +25,83 @@ class VisualTransformer(BinaryMaskDatasetMixin,
     def __init__(self, hparams):
         super(VisualTransformer, self).__init__(hparams)

-        self.in_shape = self.dataset.train_dataset.sample_shape
-        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
-        channels, height, width = self.in_shape
-        # Automatic Image Shaping
-        image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
-        self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
-        # This should be obsolete
-        assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
-        num_patches = (self.image_size // self.params.patch_size) ** 2
-        patch_dim = channels * self.params.patch_size ** 2
-        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
-                                              f'attention. Try decreasing your patch size'

         # Dataset
         # =============================================================================
         self.dataset = self.build_dataset()
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
+        channels, height, width = self.in_shape

         # Model Parameters
         # =============================================================================
         # Additional parameters
-        self.attention_dim = self.params.features
+        self.embed_dim = self.params.embedding_size
+
+        # Automatic Image Shaping
+        self.patch_size = self.params.patch_size
+        image_size = (max(height, width) // self.patch_size) * self.patch_size
+        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
+
+        # This should be obsolete
+        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
+
+        num_patches = (self.image_size // self.patch_size) ** 2
+        patch_dim = channels * self.patch_size ** 2
+        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
+                                               f'attention. Try decreasing your patch size'
+
+        # Correct the Embedding Dim
+        if not self.embed_dim % self.params.heads == 0:
+            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
+            message = ('Embedding Dimension was fixed to be divisible by the number' +
+                       f' of attention heads, is now: {self.embed_dim}')
+            for func in print, warnings.warn:
+                func(message)

         # Utility Modules
         self.autopad = AutoPadToShape((self.image_size, self.image_size))

         # Modules with Parameters
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
-        self.embedding = nn.Linear(patch_dim, self.attention_dim)
-        self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
-        self.dropout = nn.Dropout(self.params.dropout)
-
-        self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
-                                             self.params.lat_dim, self.params.dropout)
+        self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
+                                             n_heads=self.params.heads, num_layers=self.params.attn_depth,
+                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
+                                             activation=self.params.activation_as_string
+                                             )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
+        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
+            else F_x(self.embed_dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
+        self.dropout = nn.Dropout(self.params.dropout)

         self.to_cls_token = nn.Identity()

         self.mlp_head = nn.Sequential(
-            nn.LayerNorm(self.attention_dim),
-            nn.Linear(self.attention_dim, self.params.lat_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.Linear(self.embed_dim, self.params.lat_dim),
             nn.GELU(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
+            nn.Linear(self.params.lat_dim, 1),
+            nn.Sigmoid()
         )

     def forward(self, x, mask=None):
         """
-        :param tensor: the sequence to the encoder (required).
+        :param x: the sequence to the encoder (required).
         :param mask: the mask for the src sequence (optional).
         :return:
         """
+        tensor = self.autopad(x)
         p = self.params.patch_size
-        # 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
-        tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0]))
+        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
         tensor = self.patch_to_embedding(tensor)
         b, n, _ = tensor.shape

-        # '() n d -> b n d', b = b
-        cls_tokens = tensor.repeat(self.cls_token, b)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
         tensor = torch.cat((cls_tokens, tensor), dim=1)
         tensor += self.pos_embedding[:, :(n + 1)]
         tensor = self.dropout(tensor)
@@ -93,4 +110,4 @@ class VisualTransformer(BinaryMaskDatasetMixin,
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
-        return tensor
+        return Namespace(main_out=tensor)
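
The forward pass now uses einops instead of the hand-rolled reshape, which also fixes the broken cls-token expansion (`tensor.repeat(self.cls_token, b)` was not a valid call). A standalone sketch of the two patterns with toy shapes:

import torch
from einops import rearrange, repeat

p = 9                                # patch size, matching the new default
images = torch.randn(2, 1, 54, 54)   # toy batch, already padded to a multiple of p
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)                 # torch.Size([2, 36, 81]): 6x6 patches of 9*9*1 values

cls_token = torch.randn(1, 1, 81)
cls_tokens = repeat(cls_token, '() n d -> b n d', b=2)   # expand over the batch
sequence = torch.cat((cls_tokens, patches), dim=1)
print(sequence.shape)                # torch.Size([2, 37, 81])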

View File

@@ -0,0 +1,114 @@
from argparse import Namespace
import warnings

import torch
from torch import nn
from einops import repeat

from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                BaseDataloadersMixin)

MIN_NUM_PATCHES = 16


class SequentialVisualTransformer(BinaryMaskDatasetMixin,
                                  BaseDataloadersMixin,
                                  BaseTrainMixin,
                                  BaseValMixin,
                                  BaseOptimizerMixin,
                                  LightningBaseModule
                                  ):

    def __init__(self, hparams):
        super(SequentialVisualTransformer, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()
        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size
        self.patch_size = self.params.patch_size
        self.height = height

        # Automatic Image Shaping
        image_size = (max(height, width) // self.patch_size) * self.patch_size
        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size

        # This should be obsolete
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.patch_size) ** 2
        patch_dim = channels * self.patch_size * self.image_size
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))
        self.dropout = nn.Dropout(self.params.dropout)
        self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
                                             n_heads=self.params.heads, num_layers=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation_as_string
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, mask=None):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :return:
        """
        tensor = self.autopad(x)
        tensor = self.slider(tensor)
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        cls_tokens = self.cls_token.repeat((b, 1, 1))
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.transformer(tensor, mask)
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor)
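
Both transformer variants share the head-divisibility correction; reduced to its arithmetic with toy numbers (the real values come from --model_embedding_size and --model_heads):

embed_dim, heads = 100, 8    # toy values; the new defaults are 64 and 8
if embed_dim % heads != 0:
    embed_dim = (embed_dim // heads) * heads   # round down to a multiple of heads: 100 -> 96
assert embed_dim % heads == 0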

View File

@@ -1,7 +1,7 @@
 import shutil
 import warnings

-from util.config import MConfig
+from ml_lib.utils.config import Config

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -17,12 +17,12 @@ if __name__ == '__main__':
     args = main_arg_parser.parse_args()

     # Model Settings
     config = Config().read_namespace(args)
-    config = MConfig().read_namespace(args)
+    config = Config().read_namespace(args)

     arg_dict = dict()
     for seed in range(0, 10):
         arg_dict.update(main_seed=seed)
-        for model in ['CC', 'BCMC', 'BCC', 'RCC']:
+        for model in ['VisualTransformer']:
             arg_dict.update(model_type=model)
             raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
                             data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
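
The sweep now runs a single model type across ten seeds. A stand-in for the project's fix_all_random_seeds helper, showing the usual trio it presumably sets:

import random

import numpy as np
import torch

def seed_everything(seed: int):   # illustrative stand-in, not the project's helper
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

for seed in range(0, 10):
    seed_everything(seed)
    # ...update arg_dict with main_seed=seed and launch the run as above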

View File

@@ -42,7 +42,7 @@ msgpack-python==0.5.6
 natsort==7.0.1
 neptune-client==0.4.109
 numba==0.49.1
-numpy==1.18.4
+numpy~=1.18.2
 oauthlib==3.1.0
 packaging==20.3
 pandas==1.0.3
@@ -68,7 +68,7 @@ resampy==0.2.2
 retrying==1.3.3
 rfc3987==1.3.8
 rsa==4.0
-scikit-learn==0.23.1
+scikit-learn~=0.22.2.post1
 scipy==1.4.1
 simplejson==3.17.0
 six==1.14.0
@@ -91,3 +91,5 @@ webencodings==0.5.1
 websocket-client==0.57.0
 Werkzeug==1.0.1
 xmltodict==0.12.0
+einops~=0.3.0

View File

@@ -1,26 +0,0 @@
-from ml_lib.utils.config import Config
-
-from models.conv_classifier import ConvClassifier
-from models.bandwise_conv_classifier import BandwiseConvClassifier
-from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
-from models.ensemble import Ensemble
-from models.residual_conv_classifier import ResidualConvClassifier
-from models.transformer_model import VisualTransformer
-
-
-class MConfig(Config):
-    # TODO: There should be a way to automate this.
-
-    @property
-    def _model_map(self):
-        return dict(ConvClassifier=ConvClassifier,
-                    CC=ConvClassifier,
-                    BandwiseConvClassifier=BandwiseConvClassifier,
-                    BCC=BandwiseConvClassifier,
-                    BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
-                    BCMC=BandwiseConvMultiheadClassifier,
-                    Ensemble=Ensemble,
-                    E=Ensemble,
-                    ResidualConvClassifier=ResidualConvClassifier,
-                    RCC=ResidualConvClassifier,
-                    ViT=VisualTransformer
-                    )

View File

@@ -8,7 +8,8 @@ import torch
 import numpy as np

 from torch import nn
 from torch.optim import Adam
-from torch.utils.data import DataLoader, RandomSampler
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+from torch.utils.data import DataLoader
 from torchcontrib.optim import SWA
 from torchvision.transforms import Compose, RandomApply
@@ -25,10 +26,23 @@ class BaseOptimizerMixin:

     def configure_optimizers(self):
         assert isinstance(self, LightningBaseModule)
-        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer_dict = dict(
+            # 'optimizer': optimizer,        # The Optimizer
+            # 'lr_scheduler': scheduler,     # The LR scheduler
+            frequency=1,                     # The frequency of the scheduler
+            interval='epoch',                # The unit of the scheduler's step size
+            # 'reduce_on_plateau': False,    # For ReduceLROnPlateau scheduler
+            # 'monitor': 'mean_val_loss'     # Metric to monitor
+        )
+        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
-            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
-        return opt
+            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
+        optimizer_dict.update(optimizer=optimizer)
+        if self.params.lr_warmup_steps:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+            optimizer_dict.update(lr_scheduler=scheduler)
+        return optimizer_dict

     def on_train_end(self):
         assert isinstance(self, LightningBaseModule)
@@ -54,17 +68,18 @@ class BaseTrainMixin:
         assert isinstance(self, LightningBaseModule)
         batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        bce_loss = self.bce_loss(y, batch_y)
+        bce_loss = self.bce_loss(y.squeeze(), batch_y)
         return dict(loss=bce_loss)

     def training_epoch_end(self, outputs):
         assert isinstance(self, LightningBaseModule)
         keys = list(outputs[0].keys())
-        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                        for output in outputs]))
-                                 for key in keys if 'loss' in key})
-        return summary_dict
+        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                               for output in outputs]))
+                        for key in keys if 'loss' in key}
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BaseValMixin:
@@ -77,17 +92,17 @@ class BaseValMixin:
         assert isinstance(self, LightningBaseModule)
         batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        val_bce_loss = self.bce_loss(y, batch_y)
+        val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
         return dict(val_bce_loss=val_bce_loss,
                     batch_idx=batch_idx, y=y, batch_y=batch_y)

-    def validation_epoch_end(self, outputs, *args, **kwargs):
+    def validation_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict(log=dict())
+        summary_dict = dict()
         for output_idx, output in enumerate(outputs):
             keys = list(output[0].keys())
             ident = '' if output_idx == 0 else '_train'
-            summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
-                                                                                      for output in output]))
-                                        for key in keys if 'loss' in key}
-                                       )
+            summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
+                                                                               for output in output]))
+                                 for key in keys if 'loss' in key}
+                                )
@@ -101,8 +116,9 @@ class BaseValMixin:
             uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                      sample_weight=None, zero_division='warn')
             uar_score = torch.as_tensor(uar_score)
-            summary_dict['log'].update({f'uar{ident}_score': uar_score})
-        return summary_dict
+            summary_dict.update({f'uar{ident}_score': uar_score})
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BinaryMaskDatasetMixin:
@@ -139,7 +155,7 @@ class BinaryMaskDatasetMixin:
                 LoudnessManipulator(self.params.loudness_ratio),
                 ShiftTime(self.params.shift_ratio),
                 MaskAug(self.params.mask_ratio),
             ], p=0.6),
             util_transforms])

         # Datasets
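
configure_optimizers now returns the dict form Lightning accepts. A minimal self-contained version with a toy parameter; note that the first argument of CosineAnnealingWarmRestarts (T_0) is a restart period rather than a true warm-up length, so 'lr_warmup_steps' is a loose name for it:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

weights = torch.nn.Parameter(torch.zeros(10))               # toy parameter set
optimizer = Adam([weights], lr=1e-3, weight_decay=0)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # first restart after 10 scheduler steps

optimizer_dict = dict(
    optimizer=optimizer,
    lr_scheduler=scheduler,
    frequency=1,        # step the scheduler once...
    interval='epoch',   # ...per epoch
)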