From be097a111a2ea564a904d192e3f4faba65ad0d7e Mon Sep 17 00:00:00 2001 From: Si11ium Date: Sat, 21 Nov 2020 09:28:26 +0100 Subject: [PATCH] New Model, Many Changes --- _paramters.py | 35 +++--- datasets/urban_8k.py | 95 ++++++++++++++++ main.py | 36 +++--- main_inference.py | 9 +- models/bandwise_conv_classifier.py | 6 +- models/bandwise_conv_multihead_classifier.py | 4 +- models/transformer_model.py | 89 +++++++++------ models/transformer_model_sequential.py | 114 +++++++++++++++++++ multi_run.py | 6 +- requirements.txt | 6 +- util/config.py | 26 ----- util/module_mixins.py | 48 +++++--- 12 files changed, 349 insertions(+), 125 deletions(-) create mode 100644 datasets/urban_8k.py create mode 100644 models/transformer_model_sequential.py delete mode 100644 util/config.py diff --git a/_paramters.py b/_paramters.py index c119a9d..db30f6f 100644 --- a/_paramters.py +++ b/_paramters.py @@ -29,35 +29,44 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="") # Transformation Parameters -main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4 +main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4 +main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2 +main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2 main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7 main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7 # Model Parameters -main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="") +# General +main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") -main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") -main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="") -main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") -main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="") +main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="") +main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="") +main_arg_parser.add_argument("--model_features", type=int, default=64, help="") + +# CNN Specific +main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="") + +# Transformer Specific +main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="") +main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="") +main_arg_parser.add_argument("--model_heads", type=int, default=8, help="") 
+main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="") # Training Parameters main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") -# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation? main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="") +main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="") main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="") -main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="") -main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="") -main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="") +main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="") +main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="") +main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="") +main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") # Project Parameters diff --git a/datasets/urban_8k.py b/datasets/urban_8k.py new file mode 100644 index 0000000..6b0bb08 --- /dev/null +++ b/datasets/urban_8k.py @@ -0,0 +1,95 @@ +import pickle +from collections import defaultdict +from pathlib import Path + +import librosa as librosa +from torch.utils.data import Dataset +import torch + +import variables as V +from ml_lib.modules.util import F_x + + +class BinaryMasksDataset(Dataset): + _to_label = defaultdict(lambda: -1) + _to_label.update(dict(clear=V.CLEAR, mask=V.MASK)) + + @property + def sample_shape(self): + return self[0][0].shape + + def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False, + use_preprocessed=True): + self.use_preprocessed = use_preprocessed + self.stretch = stretch_dataset + assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' + assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' + super(BinaryMasksDataset, self).__init__() + + self.data_root = Path(data_root) + self.setting = setting + self._wav_folder = self.data_root / 'wav' + self._mel_folder = self.data_root / 'mel' + self.container_ext = '.pik' + self._mel_transform = mel_transforms + + self._labels = self._build_labels() + self._wav_files = list(sorted(self._labels.keys())) + self._transforms = transforms or F_x(in_shape=None) + + def _build_labels(self): + labeldict = dict() + with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: + # Exclude the header + _ = next(f) + for row in f: + if self.setting not in row: + continue + filename, label = row.strip().split(',') + labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename + if self.stretch and self.setting == V.DATA_OPTIONS.train: + additional_dict = ({f'X{key}': val for key, val in labeldict.items()}) + additional_dict.update({f'XX{key}': val for key, val in labeldict.items()}) + additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()}) + labeldict.update(additional_dict) + + # Delete File if one exists. 
+ if not self.use_preprocessed: + for key in labeldict.keys(): + try: + (self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink() + except FileNotFoundError: + pass + + return labeldict + + def __len__(self): + return len(self._labels) + + def _compute_or_retrieve(self, filename): + + if not (self._mel_folder / (filename + self.container_ext)).exists(): + raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav')) + mel_sample = self._mel_transform(raw_sample) + self._mel_folder.mkdir(exist_ok=True, parents=True) + with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f: + pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) + + with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f: + mel_sample = pickle.load(f, fix_imports=True) + return mel_sample + + def __getitem__(self, item): + + key: str = list(self._labels.keys())[item] + filename = key.replace('.wav', '') + mel_sample = self._compute_or_retrieve(filename) + label = self._labels[key] + + transformed_samples = self._transforms(mel_sample) + + if self.setting != V.DATA_OPTIONS.test: + # In test, filenames instead of labels are returned. This is a little hacky though. + label = torch.as_tensor(label, dtype=torch.float) + + return transformed_samples, label diff --git a/main.py b/main.py index 22245a9..5be95c5 100644 --- a/main.py +++ b/main.py @@ -6,14 +6,13 @@ import warnings import torch from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from ml_lib.modules.util import LightningBaseModule +from ml_lib.utils.config import Config from ml_lib.utils.logging import Logger # Project Specific Logger SubClasses -from util.config import MConfig - warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -37,35 +36,30 @@ def run_lightning_loop(config_obj): # Callbacks # ============================================================================= # Checkpoint Saving - checkpoint_callback = ModelCheckpoint( - monitor='uar_score', + ckpt_callback = ModelCheckpoint( + monitor='mean_loss', filepath=str(logger.log_dir / 'ckpt_weights'), verbose=False, save_top_k=5, ) - # Early Stopping - # TODO: For This to work, set a validation step and End Eval and Score - early_stopping_callback = EarlyStopping( - monitor='uar_score', - min_delta=0.01, - patience=10, - ) + # Learning Rate Logger + lr_logger = LearningRateMonitor(logging_interval='epoch') # Trainer # ============================================================================= trainer = Trainer(max_epochs=config_obj.train.epochs, - show_progress_bar=True, weights_save_path=logger.log_dir, gpus=[0] if torch.cuda.is_available() else None, check_val_every_n_epoch=10, # num_sanity_val_steps=config_obj.train.num_sanity_val_steps, # row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting # log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting - checkpoint_callback=checkpoint_callback, + checkpoint_callback=True, + callbacks=[lr_logger, ckpt_callback], logger=logger, fast_dev_run=config_obj.main.debug, - early_stop_callback=None + auto_lr_find=not config_obj.main.debug ) # Model @@ -78,10 +72,15 @@ def run_lightning_loop(config_obj): # Train It if config_obj.model.type.lower() != 'ensemble': + if not config_obj.main.debug and not config_obj.train.lr: + 
trainer.tune(model) + # ToDo: LR Finder Plot + # fig = lr_finder.plot(suggest=True) + trainer.fit(model) # Save the last state & all parameters - trainer.save_checkpoint(logger.log_dir / 'weights.ckpt') + trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt')) model.save_to_disk(logger.log_dir) # Evaluate It @@ -99,8 +98,7 @@ def run_lightning_loop(config_obj): outputs.append( model.validation_step((batch_x, label), idx, 1) ) - summary_dict = model.validation_epoch_end([outputs]) - print(summary_dict['log']['uar_score']) + model.validation_epoch_end([outputs]) # trainer.test() outpath = Path(config_obj.train.outpath) @@ -132,6 +130,6 @@ if __name__ == "__main__": from _paramters import main_arg_parser - config = MConfig.read_argparser(main_arg_parser) + config = Config.read_argparser(main_arg_parser) fix_all_random_seeds(config) trained_model = run_lightning_loop(config) diff --git a/main_inference.py b/main_inference.py index 54473cc..3068648 100644 --- a/main_inference.py +++ b/main_inference.py @@ -15,11 +15,10 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage # Transforms from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug -from ml_lib.utils.logging import Logger +from ml_lib.utils.config import Config from ml_lib.utils.model_io import SavedLightningModels from ml_lib.utils.transforms import ToTensor -from ml_lib.visualization.tools import Plotter -from util.config import MConfig + # Datasets from datasets.binar_masks import BinaryMasksDataset @@ -66,8 +65,8 @@ if __name__ == '__main__': config_filename = 'config.ini' inference_out = 'manual_test_out.csv' - config = MConfig() - config.read_file((Path(model_path) / config_filename).open('r')) + config = Config() + config.read_file((Path(model_path) / config_filename).open()) test_dataloader = prepare_dataloader(config) loaded_model = restore_logger_and_model(model_path) diff --git a/models/bandwise_conv_classifier.py b/models/bandwise_conv_classifier.py index 854b516..50597a5 100644 --- a/models/bandwise_conv_classifier.py +++ b/models/bandwise_conv_classifier.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger) +from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) @@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin, # Modules # ============================================================================= - self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) + self.split = Splitter(self.in_shape, self.n_band_sections) k = 3 self.band_list = ModuleList() @@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin, # last_shape = self.conv_list[-1].shape self.band_list.append(conv_list) - self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections) + self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections) self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs) self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs) diff --git a/models/bandwise_conv_multihead_classifier.py b/models/bandwise_conv_multihead_classifier.py index 8b3ff63..4f1331a 100644 --- 
a/models/bandwise_conv_multihead_classifier.py +++ b/models/bandwise_conv_multihead_classifier.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter) +from ml_lib.modules.util import (LightningBaseModule, Splitter) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) @@ -69,7 +69,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin, # Modules # ============================================================================= - self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) + self.split = Splitter(self.in_shape, self.n_band_sections) self.band_list = ModuleList() for band in range(self.n_band_sections): diff --git a/models/transformer_model.py b/models/transformer_model.py index 6a36b6a..58f5643 100644 --- a/models/transformer_model.py +++ b/models/transformer_model.py @@ -1,16 +1,19 @@ -import variables as V +from argparse import Namespace + +import warnings import torch from torch import nn +from einops import rearrange, repeat + from ml_lib.modules.blocks import TransformerModule -from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape) +from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) MIN_NUM_PATCHES = 16 - class VisualTransformer(BinaryMaskDatasetMixin, BaseDataloadersMixin, BaseTrainMixin, @@ -22,69 +25,83 @@ class VisualTransformer(BinaryMaskDatasetMixin, def __init__(self, hparams): super(VisualTransformer, self).__init__(hparams) - self.in_shape = self.dataset.train_dataset.sample_shape - assert len(self.in_shape) == 3, 'There need to be three Dimensions' - channels, height, width = self.in_shape - - # Automatic Image Shaping - image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size - self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size - - # This should be obsolete - assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size' - - num_patches = (self.image_size // self.params.patch_size) ** 2 - patch_dim = channels * self.params.patch_size ** 2 - assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \ - f'attention. 
Try decreasing your patch size' - # Dataset # ============================================================================= self.dataset = self.build_dataset() + self.in_shape = self.dataset.train_dataset.sample_shape + assert len(self.in_shape) == 3, 'There need to be three Dimensions' + channels, height, width = self.in_shape + # Model Paramters # ============================================================================= # Additional parameters - self.attention_dim = self.params.features + self.embed_dim = self.params.embedding_size + + # Automatic Image Shaping + self.patch_size = self.params.patch_size + image_size = (max(height, width) // self.patch_size) * self.patch_size + self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size + + # This should be obsolete + assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size' + + num_patches = (self.image_size // self.patch_size) ** 2 + patch_dim = channels * self.patch_size ** 2 + assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \ + f'attention. Try decreasing your patch size' + + # Correct the Embedding Dim + if not self.embed_dim % self.params.heads == 0: + self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads + message = ('Embedding Dimension was fixed to be devideable by the number' + + f' of attention heads, is now: {self.embed_dim}') + for func in print, warnings.warn: + func(message) # Utility Modules self.autopad = AutoPadToShape((self.image_size, self.image_size)) # Modules with Parameters - self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False) - self.embedding = nn.Linear(patch_dim, self.attention_dim) - self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False) - self.dropout = nn.Dropout(self.params.dropout) + self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim, + n_heads=self.params.heads, num_layers=self.params.attn_depth, + dropout=self.params.dropout, use_norm=self.params.use_norm, + activation=self.params.activation_as_string + ) - self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads, - self.params.lat_dim, self.params.dropout) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim)) + self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \ + else F_x(self.embed_dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + self.dropout = nn.Dropout(self.params.dropout) self.to_cls_token = nn.Identity() self.mlp_head = nn.Sequential( - nn.LayerNorm(self.attention_dim), - nn.Linear(self.attention_dim, self.params.lat_dim), + nn.LayerNorm(self.embed_dim), + nn.Linear(self.embed_dim, self.params.lat_dim), nn.GELU(), nn.Dropout(self.params.dropout), - nn.Linear(self.params.lat_dim, V.NUM_CLASSES) + nn.Linear(self.params.lat_dim, 1), + nn.Sigmoid() ) def forward(self, x, mask=None): """ - :param tensor: the sequence to the encoder (required). + :param x: the sequence to the encoder (required). :param mask: the mask for the src sequence (optional). 
:return: """ - + tensor = self.autopad(x) p = self.params.patch_size - # 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p - tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0])) + tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) tensor = self.patch_to_embedding(tensor) b, n, _ = tensor.shape - # '() n d -> b n d', b = b - cls_tokens = tensor.repeat(self.cls_token, b) + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + tensor = torch.cat((cls_tokens, tensor), dim=1) tensor += self.pos_embedding[:, :(n + 1)] tensor = self.dropout(tensor) @@ -93,4 +110,4 @@ class VisualTransformer(BinaryMaskDatasetMixin, tensor = self.to_cls_token(tensor[:, 0]) tensor = self.mlp_head(tensor) - return tensor \ No newline at end of file + return Namespace(main_out=tensor) diff --git a/models/transformer_model_sequential.py b/models/transformer_model_sequential.py new file mode 100644 index 0000000..54ec1c7 --- /dev/null +++ b/models/transformer_model_sequential.py @@ -0,0 +1,114 @@ +from argparse import Namespace + +import warnings + +import torch +from torch import nn + +from einops import repeat + +from ml_lib.modules.blocks import TransformerModule +from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow) +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, + BaseDataloadersMixin) + +MIN_NUM_PATCHES = 16 + +class SequentialVisualTransformer(BinaryMaskDatasetMixin, + BaseDataloadersMixin, + BaseTrainMixin, + BaseValMixin, + BaseOptimizerMixin, + LightningBaseModule + ): + + def __init__(self, hparams): + super(SequentialVisualTransformer, self).__init__(hparams) + + # Dataset + # ============================================================================= + self.dataset = self.build_dataset() + + self.in_shape = self.dataset.train_dataset.sample_shape + assert len(self.in_shape) == 3, 'There need to be three Dimensions' + channels, height, width = self.in_shape + + # Model Paramters + # ============================================================================= + # Additional parameters + self.embed_dim = self.params.embedding_size + self.patch_size = self.params.patch_size + self.height = height + + # Automatic Image Shaping + image_size = (max(height, width) // self.patch_size) * self.patch_size + self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size + + # This should be obsolete + assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size' + + num_patches = (self.image_size // self.patch_size) ** 2 + patch_dim = channels * self.patch_size * self.image_size + assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \ + f'attention. 
Try decreasing your patch size' + + # Correct the Embedding Dim + if not self.embed_dim % self.params.heads == 0: + self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads + message = ('Embedding Dimension was fixed to be devideable by the number' + + f' of attention heads, is now: {self.embed_dim}') + for func in print, warnings.warn: + func(message) + + # Utility Modules + self.autopad = AutoPadToShape((self.image_size, self.image_size)) + self.dropout = nn.Dropout(self.params.dropout) + self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False) + + # Modules with Parameters + self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim, + n_heads=self.params.heads, num_layers=self.params.attn_depth, + dropout=self.params.dropout, use_norm=self.params.use_norm, + activation=self.params.activation_as_string + ) + + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim)) + self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \ + else F_x(self.embed_dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + self.to_cls_token = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(self.embed_dim), + nn.Linear(self.embed_dim, self.params.lat_dim), + nn.GELU(), + nn.Dropout(self.params.dropout), + nn.Linear(self.params.lat_dim, 1), + nn.Sigmoid() + ) + + def forward(self, x, mask=None): + """ + :param x: the sequence to the encoder (required). + :param mask: the mask for the src sequence (optional). + :return: + """ + tensor = self.autopad(x) + tensor = self.slider(tensor) + + tensor = self.patch_to_embedding(tensor) + b, n, _ = tensor.shape + + # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + cls_tokens = self.cls_token.repeat((b, 1, 1)) + + tensor = torch.cat((cls_tokens, tensor), dim=1) + tensor += self.pos_embedding[:, :(n + 1)] + tensor = self.dropout(tensor) + + tensor = self.transformer(tensor, mask) + + tensor = self.to_cls_token(tensor[:, 0]) + tensor = self.mlp_head(tensor) + return Namespace(main_out=tensor) diff --git a/multi_run.py b/multi_run.py index 7c8b5e3..22f8e78 100644 --- a/multi_run.py +++ b/multi_run.py @@ -1,7 +1,7 @@ import shutil import warnings -from util.config import MConfig +from ml_lib.utils.config import Config warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -17,12 +17,12 @@ if __name__ == '__main__': args = main_arg_parser.parse_args() # Model Settings - config = MConfig().read_namespace(args) + config = Config().read_namespace(args) arg_dict = dict() for seed in range(0, 10): arg_dict.update(main_seed=seed) - for model in ['CC', 'BCMC', 'BCC', 'RCC']: + for model in ['VisualTransformer']: arg_dict.update(model_type=model) raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0, data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0, diff --git a/requirements.txt b/requirements.txt index d24c4ad..4388370 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,7 +42,7 @@ msgpack-python==0.5.6 natsort==7.0.1 neptune-client==0.4.109 numba==0.49.1 -numpy==1.18.4 +numpy~=1.18.2 oauthlib==3.1.0 packaging==20.3 pandas==1.0.3 @@ -68,7 +68,7 @@ resampy==0.2.2 retrying==1.3.3 rfc3987==1.3.8 rsa==4.0 -scikit-learn==0.23.1 +scikit-learn~=0.22.2.post1 scipy==1.4.1 simplejson==3.17.0 six==1.14.0 @@ -91,3 +91,5 @@ webencodings==0.5.1 websocket-client==0.57.0 
Werkzeug==1.0.1 xmltodict==0.12.0 + +einops~=0.3.0 \ No newline at end of file diff --git a/util/config.py b/util/config.py deleted file mode 100644 index 3031d6c..0000000 --- a/util/config.py +++ /dev/null @@ -1,26 +0,0 @@ -from ml_lib.utils.config import Config -from models.conv_classifier import ConvClassifier -from models.bandwise_conv_classifier import BandwiseConvClassifier -from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier -from models.ensemble import Ensemble -from models.residual_conv_classifier import ResidualConvClassifier -from models.transformer_model import VisualTransformer - - -class MConfig(Config): - # TODO: There should be a way to automate this. - - @property - def _model_map(self): - return dict(ConvClassifier=ConvClassifier, - CC=ConvClassifier, - BandwiseConvClassifier=BandwiseConvClassifier, - BCC=BandwiseConvClassifier, - BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier, - BCMC=BandwiseConvMultiheadClassifier, - Ensemble=Ensemble, - E=Ensemble, - ResidualConvClassifier=ResidualConvClassifier, - RCC=ResidualConvClassifier, - ViT=VisualTransformer - ) diff --git a/util/module_mixins.py b/util/module_mixins.py index d9ce58f..e280805 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -8,7 +8,8 @@ import torch import numpy as np from torch import nn from torch.optim import Adam -from torch.utils.data import DataLoader, RandomSampler +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts +from torch.utils.data import DataLoader from torchcontrib.optim import SWA from torchvision.transforms import Compose, RandomApply @@ -25,10 +26,23 @@ class BaseOptimizerMixin: def configure_optimizers(self): assert isinstance(self, LightningBaseModule) - opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay) + optimizer_dict = dict( + # 'optimizer':optimizer, # The Optimizer + # 'lr_scheduler': scheduler, # The LR scheduler + frequency=1, # The frequency of the scheduler + interval='epoch', # The unit of the scheduler's step size + # 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler + # 'monitor': 'mean_val_loss' # Metric to monitor + ) + + optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay) if self.params.sto_weight_avg: - opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05) - return opt + optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05) + optimizer_dict.update(optimizer=optimizer) + if self.params.lr_warmup_steps: + scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps) + optimizer_dict.update(lr_scheduler=scheduler) + return optimizer_dict def on_train_end(self): assert isinstance(self, LightningBaseModule) @@ -54,17 +68,18 @@ class BaseTrainMixin: assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - bce_loss = self.bce_loss(y, batch_y) + bce_loss = self.bce_loss(y.squeeze(), batch_y) return dict(loss=bce_loss) def training_epoch_end(self, outputs): assert isinstance(self, LightningBaseModule) keys = list(outputs[0].keys()) - summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key] - for output in outputs])) - for key in keys if 'loss' in key}) - return summary_dict + summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] + for output in outputs])) + for key in keys if 'loss' in key} + for key in summary_dict.keys(): + self.log(key, summary_dict[key]) class BaseValMixin: @@ -77,17 +92,17 
@@ class BaseValMixin: assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - val_bce_loss = self.bce_loss(y, batch_y) + val_bce_loss = self.bce_loss(y.squeeze(), batch_y) return dict(val_bce_loss=val_bce_loss, batch_idx=batch_idx, y=y, batch_y=batch_y) - def validation_epoch_end(self, outputs, *args, **kwargs): + def validation_epoch_end(self, outputs, *_, **__): assert isinstance(self, LightningBaseModule) - summary_dict = dict(log=dict()) + summary_dict = dict() for output_idx, output in enumerate(outputs): keys = list(output[0].keys()) ident = '' if output_idx == 0 else '_train' - summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key] + summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key] for output in output])) for key in keys if 'loss' in key} ) @@ -101,8 +116,9 @@ class BaseValMixin: uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', sample_weight=None, zero_division='warn') uar_score = torch.as_tensor(uar_score) - summary_dict['log'].update({f'uar{ident}_score': uar_score}) - return summary_dict + summary_dict.update({f'uar{ident}_score': uar_score}) + for key in summary_dict.keys(): + self.log(key, summary_dict[key]) class BinaryMaskDatasetMixin: @@ -139,7 +155,7 @@ class BinaryMaskDatasetMixin: LoudnessManipulator(self.params.loudness_ratio), ShiftTime(self.params.shift_ratio), MaskAug(self.params.mask_ratio), - ], p=0.6), + ], p=0.6), util_transforms]) # Datasets
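
Notes on the patch (illustrative sketches, not part of the diff):

On _paramters.py: the boolean flags use type=strtobool (presumably imported from distutils.util earlier in the file), so command-line values like "false"/"yes" are parsed into 0/1. A minimal, self-contained example of that argparse pattern:

    from argparse import ArgumentParser
    from distutils.util import strtobool

    parser = ArgumentParser()
    parser.add_argument("--model_bias", type=strtobool, default=True)

    # type= is only applied to values given on the command line;
    # the default stays a plain Python bool when the flag is absent.
    print(parser.parse_args(["--model_bias", "false"]).model_bias)  # 0
    print(parser.parse_args([]).model_bias)                         # True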
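
On datasets/urban_8k.py: _compute_or_retrieve caches each mel spectrogram as a pickle in the mel folder, so librosa only has to decode a wav once. A standalone sketch of that cache-or-compute pattern; the helper name and arguments here are illustrative, not taken from ml_lib:

    import pickle
    from pathlib import Path

    def cached_mel(mel_folder: Path, filename: str, compute_fn, ext='.pik'):
        """Return the cached mel sample for `filename`, computing and pickling it on a cache miss."""
        container = mel_folder / (filename + ext)
        if not container.exists():
            mel_folder.mkdir(exist_ok=True, parents=True)
            with container.open('wb') as f:
                pickle.dump(compute_fn(), f, protocol=pickle.HIGHEST_PROTOCOL)
        with container.open('rb') as f:
            return pickle.load(f)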
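
On models/transformer_model.py: the manual torch.reshape patching is replaced by einops rearrange/repeat. A shape-only sketch of what those two calls produce, using dummy sizes (36x36 single-channel input, patch size 9, embedding 64) rather than the real spectrogram shape:

    import torch
    from torch import nn
    from einops import rearrange, repeat

    b, c, size, p, dim = 2, 1, 36, 9, 64                        # batch, channels, image, patch, embed dim
    images = torch.randn(b, c, size, size)

    patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
    to_embedding = nn.Linear(c * p * p, dim)
    tokens = to_embedding(patches)                               # (2, 16, 64): a 4x4 grid of 9x9 patches

    cls_token = nn.Parameter(torch.randn(1, 1, dim))
    cls_tokens = repeat(cls_token, '() n d -> b n d', b=b)       # (2, 1, 64)
    tokens = torch.cat((cls_tokens, tokens), dim=1)              # (2, 17, 64)

    pos_embedding = nn.Parameter(torch.randn(1, 16 + 1, dim))
    tokens = tokens + pos_embedding[:, :tokens.shape[1]]         # ready for the transformer stack

The embedding-dimension correction in __init__ only changes anything when embedding_size is not a multiple of heads; for example 100 with 8 heads would be clamped to 96.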
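
On models/transformer_model_sequential.py: the sequential variant feeds the transformer full-height time slices (patch_dim = channels * patch_size * image_size) produced by ml_lib's SlidingWindow instead of square patches. A rough equivalent of that slicing with plain torch.Tensor.unfold, for illustration only; SlidingWindow's exact output layout is not shown in this patch:

    import torch
    from einops import rearrange

    b, c, size, p = 2, 1, 36, 9
    spectrogram = torch.randn(b, c, size, size)

    strips = spectrogram.unfold(3, p, p)                     # (2, 1, 36, 4, 9): 4 windows along the time axis
    strips = rearrange(strips, 'b c h n p -> b n (c h p)')   # (2, 4, 324) == (batch, seq len, c * size * p)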
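
On util/module_mixins.py: configure_optimizers now returns a dict instead of a bare optimizer, and train_lr_warmup_steps is passed to CosineAnnealingWarmRestarts as the restart period T_0, so despite the name it configures cosine annealing with restarts rather than a linear warm-up. A small sketch of how that scheduler moves the learning rate when stepped once per epoch, against a throwaway model:

    import torch
    from torch import nn
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    model = nn.Linear(8, 1)
    optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)   # restart every 10 steps of the scheduler

    for epoch in range(25):
        # ... run one training epoch, with optimizer.step() per batch ...
        scheduler.step()
        print(epoch, optimizer.param_groups[0]['lr'])  # decays along a cosine, returns to 1e-3 every 10 epochs

The returned dict bundles this scheduler with interval='epoch' and frequency=1; depending on the PyTorch Lightning version, those keys may need to live in a nested lr_scheduler entry rather than at the top level of the dict.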
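
On the uar_score logged in validation_epoch_end (and previously monitored by the removed EarlyStopping callback): UAR, the unweighted average recall, is macro-averaged recall over the two classes. A tiny worked example with scikit-learn:

    import numpy as np
    import sklearn.metrics

    y_true = np.array([0, 0, 0, 1, 1])
    y_pred = np.array([0, 1, 0, 1, 1])

    uar = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro')
    print(uar)  # recall(0) = 2/3, recall(1) = 2/2  ->  UAR = (2/3 + 1) / 2 ≈ 0.83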