diff --git a/_paramters.py b/_paramters.py index 5962bc7..02d1b75 100644 --- a/_paramters.py +++ b/_paramters.py @@ -2,7 +2,8 @@ from argparse import ArgumentParser, Namespace from distutils.util import strtobool from pathlib import Path -import os +NEPTUNE_API_KEY = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUu' \ + 'YWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0=' # Parameter Configuration # ============================================================================= @@ -18,10 +19,10 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help=" main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") # Data Parameters -main_arg_parser.add_argument("--data_worker", type=int, default=11, help="") +main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="") +main_arg_parser.add_argument("--data_worker", type=int, default=6, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") -main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="") -main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="") +main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="") @@ -29,17 +30,17 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="") # Transformation Parameters -main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4 -main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4 -main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2 -main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7 -main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7 +main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4 +main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4 +main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 +main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2 +main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4 +main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7 +main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7 # Model Parameters # General -main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="") +main_arg_parser.add_argument("--model_type", type=str, default="HorizontalVisualTransformer", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") @@ -63,8 +64,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False, main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="") main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="") -main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="") -main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="") +main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="") +main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="") main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="") main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") @@ -72,7 +73,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0 # Project Parameters main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="") -main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="") +main_arg_parser.add_argument("--project_neptune_key", type=str, default=NEPTUNE_API_KEY, help="") if __name__ == '__main__': # Parse it diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py index 6b0bb08..c095086 100644 --- a/datasets/binar_masks.py +++ b/datasets/binar_masks.py @@ -18,15 +18,18 @@ class BinaryMasksDataset(Dataset): def sample_shape(self): return self[0][0].shape + @property + def _fingerprint(self): + return dict(**self._mel_kwargs, normalize=self.normalize) + def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False, use_preprocessed=True): - self.use_preprocessed = use_preprocessed self.stretch = stretch_dataset assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' super(BinaryMasksDataset, self).__init__() - self.data_root = Path(data_root) + self.data_root = Path(data_root) / 'ComParE2020_Mask' self.setting = setting self._wav_folder = self.data_root / 'wav' self._mel_folder = self.data_root / 'mel' @@ -37,16 +40,36 @@ class BinaryMasksDataset(Dataset): self._wav_files = list(sorted(self._labels.keys())) self._transforms = transforms or F_x(in_shape=None) + param_storage = self._mel_folder / 'data_params.pik' + self._mel_folder.mkdir(parents=True, exist_ok=True) + try: + pik_data = param_storage.read_bytes() + fingerprint = pickle.loads(pik_data) + if fingerprint == self._fingerprint: + self.use_preprocessed = use_preprocessed + else: + print('Diverging parameters were found; Refreshing...') + param_storage.unlink() + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = True + + except FileNotFoundError: + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = True + def _build_labels(self): labeldict = dict() - with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: + labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test + with open(Path(self.data_root) / 'lab' / f'{labelfile}.csv', mode='r') as f: # Exclude the header _ = next(f) for row in f: if self.setting not in row: continue filename, label = row.strip().split(',') - labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename + labeldict[filename] = self._to_label[label.lower()] # if not self.setting == 'test' else filename if self.stretch and self.setting == V.DATA_OPTIONS.train: additional_dict = ({f'X{key}': val for key, val in labeldict.items()}) additional_dict.update({f'XX{key}': val for key, val in labeldict.items()}) diff --git a/datasets/urban_8k.py b/datasets/urban_8k.py index 6b0bb08..69d3e16 100644 --- a/datasets/urban_8k.py +++ b/datasets/urban_8k.py @@ -1,95 +1,140 @@ import pickle -from collections import defaultdict from pathlib import Path +import multiprocessing as mp import librosa as librosa -from torch.utils.data import Dataset +from torch.utils.data import Dataset, ConcatDataset import torch +from tqdm import tqdm import variables as V +from ml_lib.audio_toolset.mel_dataset import TorchMelDataset from ml_lib.modules.util import F_x -class BinaryMasksDataset(Dataset): - _to_label = defaultdict(lambda: -1) - _to_label.update(dict(clear=V.CLEAR, mask=V.MASK)) +class Urban8K(Dataset): @property def sample_shape(self): return self[0][0].shape - def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False, - use_preprocessed=True): - self.use_preprocessed = use_preprocessed - self.stretch = stretch_dataset + @property + def _fingerprint(self): + return str(self._mel_transform) + + def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None, + use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(), + **_): assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' - super(BinaryMasksDataset, self).__init__() + assert fold in range(1, 11) + super(Urban8K, self).__init__() - self.data_root = Path(data_root) + self.data_root = Path(data_root) / 'UrbanSound8K' self.setting = setting - self._wav_folder = self.data_root / 'wav' - self._mel_folder = self.data_root / 'mel' + self.num_worker = num_worker + self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10 + self.use_preprocessed = use_preprocessed + self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}' + self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}' self.container_ext = '.pik' self._mel_transform = mel_transforms self._labels = self._build_labels() self._wav_files = list(sorted(self._labels.keys())) - self._transforms = transforms or F_x(in_shape=None) + transforms = transforms or F_x(in_shape=None) + + param_storage = self._mel_folder / 'data_params.pik' + self._mel_folder.mkdir(parents=True, exist_ok=True) + try: + pik_data = param_storage.read_bytes() + fingerprint = pickle.loads(pik_data) + if fingerprint == self._fingerprint: + self.use_preprocessed = use_preprocessed + else: + print('Diverging parameters were found; Refreshing...') + param_storage.unlink() + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = False + + except FileNotFoundError: + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = False + + + while True: + if not self.use_preprocessed: + self._pre_process() + try: + self._dataset = ConcatDataset( + [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms, + segment_len=audio_segment_len, hop_len=audio_hop_len, + label=self._labels[key]['label'] + ) for key in self._labels.keys()] + ) + break + except IOError: + self.use_preprocessed = False + pass def _build_labels(self): labeldict = dict() - with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: + with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f: # Exclude the header _ = next(f) for row in f: - if self.setting not in row: - continue - filename, label = row.strip().split(',') - labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename - if self.stretch and self.setting == V.DATA_OPTIONS.train: - additional_dict = ({f'X{key}': val for key, val in labeldict.items()}) - additional_dict.update({f'XX{key}': val for key, val in labeldict.items()}) - additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()}) - labeldict.update(additional_dict) + slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',') + if int(fold) == self.fold: + key = slice_file_name.replace('.wav', '') + labeldict[key] = dict(label=int(class_id), fold=int(fold)) # Delete File if one exists. if not self.use_preprocessed: for key in labeldict.keys(): - try: - (self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink() - except FileNotFoundError: - pass + for mel_file in self._mel_folder.rglob(f'{key}_*'): + try: + mel_file.unlink(missing_ok=True) + except FileNotFoundError: + pass return labeldict def __len__(self): - return len(self._labels) + return len(self._dataset) - def _compute_or_retrieve(self, filename): + def _pre_process(self): + print('Preprocessing Mel Files....') + with mp.Pool(processes=self.num_worker) as pool: + with tqdm(total=len(self._labels)) as pbar: + for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())): + pbar.update() - if not (self._mel_folder / (filename + self.container_ext)).exists(): - raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav')) + def _build_mel(self, filename): + + wav_file = self._wav_folder / (filename.replace('X', '') + '.wav') + mel_file = list(self._mel_folder.glob(f'{filename}_*')) + + if not mel_file: + raw_sample, sr = librosa.core.load(wav_file) mel_sample = self._mel_transform(raw_sample) + m, n = mel_sample.shape + mel_file = self._mel_folder / f'{filename}_{m}_{n}' self._mel_folder.mkdir(exist_ok=True, parents=True) - with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f: + with mel_file.open(mode='wb') as f: pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) + else: + # print(f"Already existed.. Skipping {filename}") + mel_file = mel_file[0] - with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f: + with mel_file.open(mode='rb') as f: mel_sample = pickle.load(f, fix_imports=True) - return mel_sample + return mel_sample, mel_file def __getitem__(self, item): + transformed_samples, label = self._dataset[item] - key: str = list(self._labels.keys())[item] - filename = key.replace('.wav', '') - mel_sample = self._compute_or_retrieve(filename) - label = self._labels[key] - - transformed_samples = self._transforms(mel_sample) - - if self.setting != V.DATA_OPTIONS.test: - # In test, filenames instead of labels are returned. This is a little hacky though. - label = torch.as_tensor(label, dtype=torch.float) + label = torch.as_tensor(label, dtype=torch.float) return transformed_samples, label diff --git a/datasets/urban_8k_torchaudio.py b/datasets/urban_8k_torchaudio.py new file mode 100644 index 0000000..bc7b3cb --- /dev/null +++ b/datasets/urban_8k_torchaudio.py @@ -0,0 +1,140 @@ +import pickle +from pathlib import Path +import multiprocessing as mp + +import librosa as librosa +from torch.utils.data import Dataset, ConcatDataset +import torch +from tqdm import tqdm + +import variables as V +from ml_lib.audio_toolset.mel_dataset import TorchMelDataset +from ml_lib.modules.util import F_x + + +class Urban8K_TO(Dataset): + + @property + def sample_shape(self): + return self[0][0].shape + + @property + def _fingerprint(self): + return str(self._mel_transform) + + def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None, + use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(), + **_): + assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' + assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' + assert fold in range(1, 11) + super(Urban8K_TO, self).__init__() + + self.data_root = Path(data_root) / 'UrbanSound8K' + self.setting = setting + self.num_worker = num_worker + self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10 + self.use_preprocessed = use_preprocessed + self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}' + self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}' + self.container_ext = '.pik' + self._mel_transform = mel_transforms + + self._labels = self._build_labels() + self._wav_files = list(sorted(self._labels.keys())) + transforms = transforms or F_x(in_shape=None) + + param_storage = self._mel_folder / 'data_params.pik' + self._mel_folder.mkdir(parents=True, exist_ok=True) + try: + pik_data = param_storage.read_bytes() + fingerprint = pickle.loads(pik_data) + if fingerprint == self._fingerprint: + self.use_preprocessed = use_preprocessed + else: + print('Diverging parameters were found; Refreshing...') + param_storage.unlink() + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = False + + except FileNotFoundError: + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + self.use_preprocessed = False + + + while True: + if not self.use_preprocessed: + self._pre_process() + try: + self._dataset = ConcatDataset( + [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms, + segment_len=audio_segment_len, hop_len=audio_hop_len, + label=self._labels[key]['label'] + ) for key in self._labels.keys()] + ) + break + except IOError: + self.use_preprocessed = False + pass + + def _build_labels(self): + labeldict = dict() + with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f: + # Exclude the header + _ = next(f) + for row in f: + slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',') + if int(fold) == self.fold: + key = slice_file_name.replace('.wav', '') + labeldict[key] = dict(label=int(class_id), fold=int(fold)) + + # Delete File if one exists. + if not self.use_preprocessed: + for key in labeldict.keys(): + for mel_file in self._mel_folder.rglob(f'{key}_*'): + try: + mel_file.unlink(missing_ok=True) + except FileNotFoundError: + pass + + return labeldict + + def __len__(self): + return len(self._dataset) + + def _pre_process(self): + print('Preprocessing Mel Files....') + with mp.Pool(processes=self.num_worker) as pool: + with tqdm(total=len(self._labels)) as pbar: + for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())): + pbar.update() + + def _build_mel(self, filename): + + wav_file = self._wav_folder / (filename.replace('X', '') + '.wav') + mel_file = list(self._mel_folder.glob(f'{filename}_*')) + + if not mel_file: + raw_sample, sr = librosa.core.load(wav_file) + mel_sample = self._mel_transform(raw_sample) + m, n = mel_sample.shape + mel_file = self._mel_folder / f'{filename}_{m}_{n}' + self._mel_folder.mkdir(exist_ok=True, parents=True) + with mel_file.open(mode='wb') as f: + pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) + else: + # print(f"Already existed.. Skipping {filename}") + mel_file = mel_file[0] + + with mel_file.open(mode='rb') as f: + mel_sample = pickle.load(f, fix_imports=True) + return mel_sample, mel_file + + def __getitem__(self, item): + transformed_samples, label = self._dataset[item] + + label = torch.as_tensor(label, dtype=torch.float) + + return transformed_samples, label diff --git a/main.py b/main.py index 5be95c5..c7ffce8 100644 --- a/main.py +++ b/main.py @@ -82,47 +82,9 @@ def run_lightning_loop(config_obj): # Save the last state & all parameters trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt')) model.save_to_disk(logger.log_dir) + # trainer.run_evaluation(test_mode=True) - # Evaluate It - if config_obj.main.eval: - with torch.no_grad(): - model.eval() - if torch.cuda.is_available(): - model.cuda() - outputs = [] - from tqdm import tqdm - for idx, batch in enumerate(tqdm(model.val_dataloader()[0])): - batch_x, label = batch - batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu') - label = label.to(device='cuda' if model.on_gpu else 'cpu') - outputs.append( - model.validation_step((batch_x, label), idx, 1) - ) - model.validation_epoch_end([outputs]) - # trainer.test() - outpath = Path(config_obj.train.outpath) - model_type = config_obj.model.type - parameters = logger.name - version = f'version_{logger.version}' - inference_out = f'{parameters}_test_out.csv' - - from main_inference import prepare_dataloader - import variables as V - test_dataloader = prepare_dataloader(config_obj) - - with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile: - outfile.write(f'file_name,prediction\n') - - from tqdm import tqdm - for batch in tqdm(test_dataloader, total=len(test_dataloader)): - batch_x, file_names = batch - batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu') - y = model(batch_x).main_out - predictions = (y >= 0.5).int() - for prediction, file_name in zip(predictions, file_names): - prediction_text = 'clear' if prediction == V.CLEAR else 'mask' - outfile.write(f'{file_name},{prediction_text}\n') return model diff --git a/models/bandwise_conv_classifier.py b/models/bandwise_conv_classifier.py index 50597a5..072ea7b 100644 --- a/models/bandwise_conv_classifier.py +++ b/models/bandwise_conv_classifier.py @@ -5,11 +5,11 @@ from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger) -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, BaseDataloadersMixin) -class BandwiseConvClassifier(BinaryMaskDatasetMixin, +class BandwiseConvClassifier(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, diff --git a/models/bandwise_conv_multihead_classifier.py b/models/bandwise_conv_multihead_classifier.py index 4f1331a..578b4cf 100644 --- a/models/bandwise_conv_multihead_classifier.py +++ b/models/bandwise_conv_multihead_classifier.py @@ -6,11 +6,11 @@ from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule from ml_lib.modules.util import (LightningBaseModule, Splitter) -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, BaseDataloadersMixin) -class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin, +class BandwiseConvMultiheadClassifier(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, diff --git a/models/conv_classifier.py b/models/conv_classifier.py index f89930a..eacbec9 100644 --- a/models/conv_classifier.py +++ b/models/conv_classifier.py @@ -5,11 +5,11 @@ from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule from ml_lib.modules.util import LightningBaseModule -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, BaseDataloadersMixin) -class ConvClassifier(BinaryMaskDatasetMixin, +class ConvClassifier(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, diff --git a/models/ensemble.py b/models/ensemble.py index 5b446b1..f599266 100644 --- a/models/ensemble.py +++ b/models/ensemble.py @@ -8,11 +8,11 @@ from torch.nn import ModuleList from ml_lib.modules.util import LightningBaseModule from ml_lib.utils.config import Config from ml_lib.utils.model_io import SavedLightningModels -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, BaseDataloadersMixin) -class Ensemble(BinaryMaskDatasetMixin, +class Ensemble(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, diff --git a/models/residual_conv_classifier.py b/models/residual_conv_classifier.py index 51fd7fd..5aa1a0c 100644 --- a/models/residual_conv_classifier.py +++ b/models/residual_conv_classifier.py @@ -5,11 +5,11 @@ from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule from ml_lib.modules.util import LightningBaseModule -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, BaseDataloadersMixin) -class ResidualConvClassifier(BinaryMaskDatasetMixin, +class ResidualConvClassifier(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, diff --git a/models/transformer_model.py b/models/transformer_model.py index 58f5643..80deea4 100644 --- a/models/transformer_model.py +++ b/models/transformer_model.py @@ -9,15 +9,16 @@ from einops import rearrange, repeat from ml_lib.modules.blocks import TransformerModule from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x) -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, - BaseDataloadersMixin) +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, + BaseDataloadersMixin, BaseTestMixin) MIN_NUM_PATCHES = 16 -class VisualTransformer(BinaryMaskDatasetMixin, +class VisualTransformer(DatasetMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, + BaseTestMixin, BaseOptimizerMixin, LightningBaseModule ): diff --git a/models/transformer_model_horizontal.py b/models/transformer_model_horizontal.py new file mode 100644 index 0000000..701ccb4 --- /dev/null +++ b/models/transformer_model_horizontal.py @@ -0,0 +1,111 @@ +from argparse import Namespace + +import warnings + +import torch +from torch import nn + +from ml_lib.modules.blocks import TransformerModule +from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow) +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, + BaseDataloadersMixin, BaseTestMixin) + +MIN_NUM_PATCHES = 16 + +class HorizontalVisualTransformer(DatasetMixin, + BaseDataloadersMixin, + BaseTrainMixin, + BaseValMixin, + BaseTestMixin, + BaseOptimizerMixin, + LightningBaseModule + ): + + def __init__(self, hparams): + super(HorizontalVisualTransformer, self).__init__(hparams) + + # Dataset + # ============================================================================= + self.dataset = self.build_dataset() + + self.in_shape = self.dataset.train_dataset.sample_shape + assert len(self.in_shape) == 3, 'There need to be three Dimensions' + channels, height, width = self.in_shape + + # Model Paramters + # ============================================================================= + # Additional parameters + self.embed_dim = self.params.embedding_size + self.patch_size = self.params.patch_size + self.height = height + self.width = width + self.channels = channels + + self.new_height = ((self.height - self.patch_size)//1) + 1 + + num_patches = self.new_height - (self.patch_size // 2) + patch_dim = channels * self.patch_size * self.width + assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \ + f'attention. Try decreasing your patch size' + + # Correct the Embedding Dim + if not self.embed_dim % self.params.heads == 0: + self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads + message = ('Embedding Dimension was fixed to be devideable by the number' + + f' of attention heads, is now: {self.embed_dim}') + for func in print, warnings.warn: + func(message) + + # Utility Modules + self.autopad = AutoPadToShape((self.new_height, self.width)) + self.dropout = nn.Dropout(self.params.dropout) + self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.patch_size, self.width), + keepdim=False) + + # Modules with Parameters + self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim, + n_heads=self.params.heads, num_layers=self.params.attn_depth, + dropout=self.params.dropout, use_norm=self.params.use_norm, + activation=self.params.activation_as_string + ) + + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim)) + self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \ + else F_x(self.embed_dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + self.to_cls_token = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(self.embed_dim), + nn.Linear(self.embed_dim, self.params.lat_dim), + nn.GELU(), + nn.Dropout(self.params.dropout), + nn.Linear(self.params.lat_dim, 1), + nn.Sigmoid() + ) + + def forward(self, x, mask=None): + """ + :param x: the sequence to the encoder (required). + :param mask: the mask for the src sequence (optional). + :return: + """ + tensor = self.autopad(x) + tensor = self.slider(tensor) + + tensor = self.patch_to_embedding(tensor) + b, n, _ = tensor.shape + + # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + cls_tokens = self.cls_token.repeat((b, 1, 1)) + + tensor = torch.cat((cls_tokens, tensor), dim=1) + tensor += self.pos_embedding[:, :(n + 1)] + tensor = self.dropout(tensor) + + tensor = self.transformer(tensor, mask) + + tensor = self.to_cls_token(tensor[:, 0]) + tensor = self.mlp_head(tensor) + return Namespace(main_out=tensor) diff --git a/models/transformer_model_sequential.py b/models/transformer_model_vertical.py similarity index 88% rename from models/transformer_model_sequential.py rename to models/transformer_model_vertical.py index 2a9bc97..914baf4 100644 --- a/models/transformer_model_sequential.py +++ b/models/transformer_model_vertical.py @@ -7,21 +7,22 @@ from torch import nn from ml_lib.modules.blocks import TransformerModule from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow) -from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, - BaseDataloadersMixin) +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, + BaseDataloadersMixin, BaseTestMixin) MIN_NUM_PATCHES = 16 -class SequentialVisualTransformer(BinaryMaskDatasetMixin, - BaseDataloadersMixin, - BaseTrainMixin, - BaseValMixin, - BaseOptimizerMixin, - LightningBaseModule - ): +class VerticalVisualTransformer(DatasetMixin, + BaseDataloadersMixin, + BaseTrainMixin, + BaseValMixin, + BaseTestMixin, + BaseOptimizerMixin, + LightningBaseModule + ): def __init__(self, hparams): - super(SequentialVisualTransformer, self).__init__(hparams) + super(VerticalVisualTransformer, self).__init__(hparams) # Dataset # ============================================================================= diff --git a/multi_run.py b/multi_run.py index 5e996cf..7777d1e 100644 --- a/multi_run.py +++ b/multi_run.py @@ -4,6 +4,7 @@ from _paramters import main_arg_parser from main import run_lightning_loop import warnings +import shutil from ml_lib.utils.config import Config @@ -22,7 +23,7 @@ if __name__ == '__main__': arg_dict.update(main_seed=seed) if False: for patch_size in [3, 5 , 9]: - for model in ['SequentialVisualTransformer']: + for model in ['VerticalVisualTransformer']: arg_dict.update(model_type=model, model_patch_size=patch_size) raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0, data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0, @@ -52,12 +53,12 @@ if __name__ == '__main__': arg_dict.update(dicts) if True: - for patch_size in [3, 7]: - for lat_dim in [4, 32]: - for heads in [2, 4]: - for embedding_size in [32, 64]: - for attn_depth in [1, 3]: - for model in ['SequentialVisualTransformer', 'VisualTransformer']: + for patch_size in [7]: + for lat_dim in [32]: + for heads in [8]: + for embedding_size in [7**2]: + for attn_depth in [1, 3, 5, 7]: + for model in ['HorizontalVisualTransformer']: arg_dict.update( model_type=model, model_patch_size=patch_size, diff --git a/util/module_mixins.py b/util/module_mixins.py index e280805..be2a7b8 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -103,9 +103,9 @@ class BaseValMixin: keys = list(output[0].keys()) ident = '' if output_idx == 0 else '_train' summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key] - for output in output])) - for key in keys if 'loss' in key} - ) + for output in output])) + for key in keys if 'loss' in key} + ) # UnweightedAverageRecall y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy() @@ -121,7 +121,45 @@ class BaseValMixin: self.log(key, summary_dict[key]) -class BinaryMaskDatasetMixin: +class BaseTestMixin: + + absolute_loss = nn.L1Loss() + nll_loss = nn.NLLLoss() + bce_loss = nn.BCELoss() + + def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs): + assert isinstance(self, LightningBaseModule) + batch_x, batch_y = batch_xy + y = self(batch_x).main_out + test_bce_loss = self.bce_loss(y.squeeze(), batch_y) + return dict(test_bce_loss=test_bce_loss, + batch_idx=batch_idx, y=y, batch_y=batch_y) + + def test_epoch_end(self, outputs, *_, **__): + assert isinstance(self, LightningBaseModule) + summary_dict = dict() + + keys = list(outputs[0].keys()) + summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key] + for output in outputs])) + for key in keys if 'loss' in key} + ) + + # UnweightedAverageRecall + y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() + y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() + + y_pred = (y_pred >= 0.5).astype(np.float32) + + uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', + sample_weight=None, zero_division='warn') + uar_score = torch.as_tensor(uar_score) + summary_dict.update({f'uar_score': uar_score}) + for key in summary_dict.keys(): + self.log(key, summary_dict[key]) + + +class DatasetMixin: def build_dataset(self): assert isinstance(self, LightningBaseModule) @@ -159,21 +197,20 @@ class BinaryMaskDatasetMixin: util_transforms]) # Datasets - from datasets.binar_masks import BinaryMasksDataset dataset = Namespace( **dict( # TRAIN DATASET - train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, + train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train, use_preprocessed=self.params.use_preprocessed, stretch_dataset=self.params.stretch, mel_transforms=mel_transforms_train, transforms=aug_transforms), # VALIDATION DATASET - val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, + val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train, mel_transforms=mel_transforms, transforms=util_transforms), - val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel, + val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel, mel_transforms=mel_transforms, transforms=util_transforms), # TEST DATASET - test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test, + test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test, mel_transforms=mel_transforms, transforms=util_transforms), ) ) @@ -190,22 +227,23 @@ class BaseDataloadersMixin(ABC): # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset)) sampler = None return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler, - batch_size=self.params.batch_size, + batch_size=self.params.batch_size, pin_memory=True, num_workers=self.params.worker) # Test Dataloader def test_dataloader(self): assert isinstance(self, LightningBaseModule) return DataLoader(dataset=self.dataset.test_dataset, shuffle=False, - batch_size=self.params.batch_size, + batch_size=self.params.batch_size, pin_memory=True, num_workers=self.params.worker) # Validation Dataloader def val_dataloader(self): assert isinstance(self, LightningBaseModule) - val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, + val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True, batch_size=self.params.batch_size, num_workers=self.params.worker) train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker, + pin_memory=True, batch_size=self.params.batch_size, shuffle=False) return [val_dataloader, train_dataloader]