From 68431b848e5b332c8d0fdbf11f6dd9370a597857 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Thu, 17 Dec 2020 08:02:29 +0100 Subject: [PATCH] torchaudio testing --- _paramters.py | 2 +- datasets/base_dataset.py | 104 +++++++ datasets/binar_masks.py | 7 +- datasets/urban_8k.py | 154 ++++------- datasets/urban_8k_torchaudio.py | 140 ---------- ensemble_methods/global_inference.py | 6 +- main_inference.py | 6 +- models/transformer_model.py | 7 +- models/transformer_model_horizontal.py | 8 +- models/transformer_model_vertical.py | 7 +- multi_run.py | 174 ++++++++---- util/module_mixins.py | 363 ++++++++++++++++++------- variables.py | 18 +- 13 files changed, 578 insertions(+), 418 deletions(-) create mode 100644 datasets/base_dataset.py delete mode 100644 datasets/urban_8k_torchaudio.py diff --git a/_paramters.py b/_paramters.py index 02d1b75..3a1ea07 100644 --- a/_paramters.py +++ b/_paramters.py @@ -22,7 +22,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="") main_arg_parser.add_argument("--data_worker", type=int, default=6, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") -main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--data_reset", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="") diff --git a/datasets/base_dataset.py b/datasets/base_dataset.py new file mode 100644 index 0000000..2291d6f --- /dev/null +++ b/datasets/base_dataset.py @@ -0,0 +1,104 @@ +import pickle +from pathlib import Path +from typing import Union +from abc import ABC + +import variables as V + +from torch.utils.data import Dataset + +class BaseAudioToMelDataset(Dataset, ABC): + + @property + def task_type(self): + return self._task_type + + @property + def classes(self): + return V.multi_classes + + @property + def n_classes(self): + return V.N_CLASS_binary if self.task_type == V.TASK_OPTION_binary else V.N_CLASS_multi + + @property + def sample_shape(self): + return self[0][0].shape + + @property + def _fingerprint(self): + raise NotImplementedError + return str(self._mel_transform) + + # Data Structures + @property + def mel_folder(self): + return self.data_root / 'mel' + + @property + def wav_folder(self): + return self.data_root / self._wav_folder_name + + @property + def _container_ext(self): + return '.mel' + + def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs, + mel_augmentations=None, audio_augmentations=None, reset=False, + wav_folder_name='wav', **_): + super(BaseAudioToMelDataset, self).__init__() + + # Dataset Parameters + self.data_root = Path(data_root) + self._wav_folder_name = wav_folder_name + self.reset = reset + self.mel_kwargs = mel_kwargs + + # Transformations + self.mel_augmentations = mel_augmentations + self.audio_augmentations = audio_augmentations + self._task_type = task_type + + # Find all raw files and turn generator to persistent list: + self._wav_files = list(self.wav_folder.rglob('*.wav')) + + # Build the Dataset + self._dataset = self._build_dataset() + + + def __len__(self): + raise NotImplementedError + + def __getitem__(self, item): + raise NotImplementedError + + def _build_dataset(self): + raise NotImplementedError 
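#
# --- Editor's note, not part of this patch ----------------------------------
# A minimal sketch of how a concrete subclass could fill in the abstract hooks
# above. The subclass name is hypothetical, and the LibrosaAudioToMelDataset
# call pattern is only inferred from datasets/urban_8k.py further down in this
# patch; the actual ml_lib signature is not verified here.

from torch.utils.data import ConcatDataset

import variables as V
from datasets.base_dataset import BaseAudioToMelDataset
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset


class ToyAudioToMelDataset(BaseAudioToMelDataset):

    def __init__(self, data_root, mel_kwargs, **kwargs):
        # task_type drives n_classes; the base __init__ collects the wav files
        # and then calls _build_dataset() itself.
        super(ToyAudioToMelDataset, self).__init__(data_root, V.TASK_OPTION_multiclass,
                                                   mel_kwargs, **kwargs)

    def _build_dataset(self):
        # One sub-dataset per wav file; label 0 is a placeholder.
        return ConcatDataset([LibrosaAudioToMelDataset(wav_file, 0, **self.__dict__)
                              for wav_file in self._wav_files])

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, item):
        return self._dataset[item]
# --- end of editor's note ----------------------------------------------------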
+ + def _check_reset_and_clean_up(self, reset): + all_mel_folders = set([str(x.parent).replace(self._wav_folder_name, 'mel') for x in self._wav_files]) + for mel_folder in all_mel_folders: + param_storage = Path(mel_folder) / 'data_params.pik' + param_storage.parent.mkdir(parents=True, exist_ok=True) + try: + pik_data = param_storage.read_bytes() + fingerprint = pickle.loads(pik_data) + if fingerprint == self._fingerprint: + this_reset = reset + else: + print('Diverging parameters were found; Refreshing...') + param_storage.unlink() + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + this_reset = True + + except FileNotFoundError: + pik_data = pickle.dumps(self._fingerprint) + param_storage.write_bytes(pik_data) + this_reset = True + + if this_reset: + all_mel_files = self.mel_folder.rglob(f'*{self._container_ext}') + for mel_file in all_mel_files: + mel_file.unlink() + diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py index c095086..afe54f8 100644 --- a/datasets/binar_masks.py +++ b/datasets/binar_masks.py @@ -20,17 +20,20 @@ class BinaryMasksDataset(Dataset): @property def _fingerprint(self): - return dict(**self._mel_kwargs, normalize=self.normalize) + return dict(**self._mel_kwargs if self._mel_kwargs else dict()) def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False, - use_preprocessed=True): + use_preprocessed=True, mel_kwargs=None): self.stretch = stretch_dataset assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' super(BinaryMasksDataset, self).__init__() + self.task = V.TASK_OPTION_binary + self.data_root = Path(data_root) / 'ComParE2020_Mask' self.setting = setting + self._mel_kwargs = mel_kwargs self._wav_folder = self.data_root / 'wav' self._mel_folder = self.data_root / 'mel' self.container_ext = '.pik' diff --git a/datasets/urban_8k.py b/datasets/urban_8k.py index 69d3e16..fd49231 100644 --- a/datasets/urban_8k.py +++ b/datasets/urban_8k.py @@ -1,140 +1,78 @@ -import pickle from pathlib import Path -import multiprocessing as mp +from typing import Union, List -import librosa as librosa -from torch.utils.data import Dataset, ConcatDataset +import multiprocessing as mp +from torch.utils.data import ConcatDataset import torch from tqdm import tqdm import variables as V -from ml_lib.audio_toolset.mel_dataset import TorchMelDataset -from ml_lib.modules.util import F_x +from datasets.base_dataset import BaseAudioToMelDataset +from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset -class Urban8K(Dataset): +try: + torch.multiprocessing.set_sharing_strategy('file_system') +except AttributeError: + pass - @property - def sample_shape(self): - return self[0][0].shape +class Urban8K(BaseAudioToMelDataset): - @property - def _fingerprint(self): - return str(self._mel_transform) - - def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None, - use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(), - **_): + def __init__(self, + data_root, setting, fold: Union[int, List]=1, num_worker=mp.cpu_count(), + reset=False, sample_segment_len=50, sample_hop_len=20, + **kwargs): + self.num_worker = num_worker assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' 
- assert fold in range(1, 11) - super(Urban8K, self).__init__() + assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold]) - self.data_root = Path(data_root) / 'UrbanSound8K' + #Dataset Paramters self.setting = setting - self.num_worker = num_worker - self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10 - self.use_preprocessed = use_preprocessed - self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}' - self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}' - self.container_ext = '.pik' - self._mel_transform = mel_transforms + fold = fold if self.setting != V.DATA_OPTION_test else 10 + self.fold = fold if isinstance(fold, list) else [fold] - self._labels = self._build_labels() - self._wav_files = list(sorted(self._labels.keys())) - transforms = transforms or F_x(in_shape=None) + self.sample_segment_len = sample_segment_len + self.sample_hop_len = sample_hop_len - param_storage = self._mel_folder / 'data_params.pik' - self._mel_folder.mkdir(parents=True, exist_ok=True) - try: - pik_data = param_storage.read_bytes() - fingerprint = pickle.loads(pik_data) - if fingerprint == self._fingerprint: - self.use_preprocessed = use_preprocessed - else: - print('Diverging parameters were found; Refreshing...') - param_storage.unlink() - pik_data = pickle.dumps(self._fingerprint) - param_storage.write_bytes(pik_data) - self.use_preprocessed = False + # Dataset specific super init + super(Urban8K, self).__init__(Path(data_root) / 'UrbanSound8K', + V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs + ) - except FileNotFoundError: - pik_data = pickle.dumps(self._fingerprint) - param_storage.write_bytes(pik_data) - self.use_preprocessed = False + def _build_subdataset(self, row): + slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',') + fold, class_id = (int(x) for x in (fold, class_id)) + if int(fold) in self.fold: + audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name + return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__) + else: + return None - - while True: - if not self.use_preprocessed: - self._pre_process() - try: - self._dataset = ConcatDataset( - [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms, - segment_len=audio_segment_len, hop_len=audio_hop_len, - label=self._labels[key]['label'] - ) for key in self._labels.keys()] - ) - break - except IOError: - self.use_preprocessed = False - pass - - def _build_labels(self): - labeldict = dict() + def _build_dataset(self): + dataset= list() with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f: # Exclude the header _ = next(f) - for row in f: - slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',') - if int(fold) == self.fold: - key = slice_file_name.replace('.wav', '') - labeldict[key] = dict(label=int(class_id), fold=int(fold)) + all_rows = list(f) + chunksize = len(all_rows) // max(self.num_worker, 1) + with mp.Pool(processes=self.num_worker) as pool: + with tqdm(total=len(all_rows)) as pbar: + for i, sub_dataset in enumerate( + pool.imap_unordered(self._build_subdataset, all_rows, chunksize=chunksize)): + pbar.update() + dataset.append(sub_dataset) - # Delete File if one exists. 
- if not self.use_preprocessed: - for key in labeldict.keys(): - for mel_file in self._mel_folder.rglob(f'{key}_*'): - try: - mel_file.unlink(missing_ok=True) - except FileNotFoundError: - pass - - return labeldict + dataset = ConcatDataset([x for x in dataset if x is not None]) + return dataset def __len__(self): return len(self._dataset) - def _pre_process(self): - print('Preprocessing Mel Files....') - with mp.Pool(processes=self.num_worker) as pool: - with tqdm(total=len(self._labels)) as pbar: - for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())): - pbar.update() - - def _build_mel(self, filename): - - wav_file = self._wav_folder / (filename.replace('X', '') + '.wav') - mel_file = list(self._mel_folder.glob(f'{filename}_*')) - - if not mel_file: - raw_sample, sr = librosa.core.load(wav_file) - mel_sample = self._mel_transform(raw_sample) - m, n = mel_sample.shape - mel_file = self._mel_folder / f'{filename}_{m}_{n}' - self._mel_folder.mkdir(exist_ok=True, parents=True) - with mel_file.open(mode='wb') as f: - pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) - else: - # print(f"Already existed.. Skipping {filename}") - mel_file = mel_file[0] - - with mel_file.open(mode='rb') as f: - mel_sample = pickle.load(f, fix_imports=True) - return mel_sample, mel_file def __getitem__(self, item): transformed_samples, label = self._dataset[item] - label = torch.as_tensor(label, dtype=torch.float) + label = torch.as_tensor(label, dtype=torch.int) return transformed_samples, label diff --git a/datasets/urban_8k_torchaudio.py b/datasets/urban_8k_torchaudio.py deleted file mode 100644 index bc7b3cb..0000000 --- a/datasets/urban_8k_torchaudio.py +++ /dev/null @@ -1,140 +0,0 @@ -import pickle -from pathlib import Path -import multiprocessing as mp - -import librosa as librosa -from torch.utils.data import Dataset, ConcatDataset -import torch -from tqdm import tqdm - -import variables as V -from ml_lib.audio_toolset.mel_dataset import TorchMelDataset -from ml_lib.modules.util import F_x - - -class Urban8K_TO(Dataset): - - @property - def sample_shape(self): - return self[0][0].shape - - @property - def _fingerprint(self): - return str(self._mel_transform) - - def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None, - use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(), - **_): - assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' - assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' 
- assert fold in range(1, 11) - super(Urban8K_TO, self).__init__() - - self.data_root = Path(data_root) / 'UrbanSound8K' - self.setting = setting - self.num_worker = num_worker - self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10 - self.use_preprocessed = use_preprocessed - self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}' - self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}' - self.container_ext = '.pik' - self._mel_transform = mel_transforms - - self._labels = self._build_labels() - self._wav_files = list(sorted(self._labels.keys())) - transforms = transforms or F_x(in_shape=None) - - param_storage = self._mel_folder / 'data_params.pik' - self._mel_folder.mkdir(parents=True, exist_ok=True) - try: - pik_data = param_storage.read_bytes() - fingerprint = pickle.loads(pik_data) - if fingerprint == self._fingerprint: - self.use_preprocessed = use_preprocessed - else: - print('Diverging parameters were found; Refreshing...') - param_storage.unlink() - pik_data = pickle.dumps(self._fingerprint) - param_storage.write_bytes(pik_data) - self.use_preprocessed = False - - except FileNotFoundError: - pik_data = pickle.dumps(self._fingerprint) - param_storage.write_bytes(pik_data) - self.use_preprocessed = False - - - while True: - if not self.use_preprocessed: - self._pre_process() - try: - self._dataset = ConcatDataset( - [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms, - segment_len=audio_segment_len, hop_len=audio_hop_len, - label=self._labels[key]['label'] - ) for key in self._labels.keys()] - ) - break - except IOError: - self.use_preprocessed = False - pass - - def _build_labels(self): - labeldict = dict() - with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f: - # Exclude the header - _ = next(f) - for row in f: - slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',') - if int(fold) == self.fold: - key = slice_file_name.replace('.wav', '') - labeldict[key] = dict(label=int(class_id), fold=int(fold)) - - # Delete File if one exists. - if not self.use_preprocessed: - for key in labeldict.keys(): - for mel_file in self._mel_folder.rglob(f'{key}_*'): - try: - mel_file.unlink(missing_ok=True) - except FileNotFoundError: - pass - - return labeldict - - def __len__(self): - return len(self._dataset) - - def _pre_process(self): - print('Preprocessing Mel Files....') - with mp.Pool(processes=self.num_worker) as pool: - with tqdm(total=len(self._labels)) as pbar: - for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())): - pbar.update() - - def _build_mel(self, filename): - - wav_file = self._wav_folder / (filename.replace('X', '') + '.wav') - mel_file = list(self._mel_folder.glob(f'{filename}_*')) - - if not mel_file: - raw_sample, sr = librosa.core.load(wav_file) - mel_sample = self._mel_transform(raw_sample) - m, n = mel_sample.shape - mel_file = self._mel_folder / f'{filename}_{m}_{n}' - self._mel_folder.mkdir(exist_ok=True, parents=True) - with mel_file.open(mode='wb') as f: - pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) - else: - # print(f"Already existed.. 
Skipping {filename}") - mel_file = mel_file[0] - - with mel_file.open(mode='rb') as f: - mel_sample = pickle.load(f, fix_imports=True) - return mel_sample, mel_file - - def __getitem__(self, item): - transformed_samples, label = self._dataset[item] - - label = torch.as_tensor(label, dtype=torch.float) - - return transformed_samples, label diff --git a/ensemble_methods/global_inference.py b/ensemble_methods/global_inference.py index 950b295..5f26609 100644 --- a/ensemble_methods/global_inference.py +++ b/ensemble_methods/global_inference.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, RandomApply from ml_lib.audio_toolset.audio_augmentation import Speed -from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage +from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, NormalizeLocal, MelToImage # Dataset and Dataloaders # ============================================================================= @@ -28,8 +28,8 @@ from datasets.binar_masks import BinaryMasksDataset def prepare_dataloader(config_obj): mel_transforms = Compose([ - AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft, - hop_length=config_obj.data.hop_length), + LibrosaAudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft, + hop_length=config_obj.data.hop_length), MelToImage()]) transforms = Compose([NormalizeLocal(), ToTensor()]) """ diff --git a/main_inference.py b/main_inference.py index 3068648..9c83336 100644 --- a/main_inference.py +++ b/main_inference.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, RandomApply from ml_lib.audio_toolset.audio_augmentation import Speed -from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage +from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, NormalizeLocal, MelToImage # Dataset and Dataloaders # ============================================================================= @@ -26,8 +26,8 @@ from datasets.binar_masks import BinaryMasksDataset def prepare_dataloader(config_obj): mel_transforms = Compose([ - AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft, - hop_length=config_obj.data.hop_length), + LibrosaAudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft, + hop_length=config_obj.data.hop_length), MelToImage()]) transforms = Compose([NormalizeLocal(), ToTensor()]) aug_transforms = Compose([ diff --git a/models/transformer_model.py b/models/transformer_model.py index 80deea4..156c0a0 100644 --- a/models/transformer_model.py +++ b/models/transformer_model.py @@ -10,11 +10,12 @@ from einops import rearrange, repeat from ml_lib.modules.blocks import TransformerModule from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, - BaseDataloadersMixin, BaseTestMixin) + BaseDataloadersMixin, BaseTestMixin, BaseLossMixin) MIN_NUM_PATCHES = 16 class VisualTransformer(DatasetMixin, + BaseLossMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, @@ -84,8 +85,8 @@ class VisualTransformer(DatasetMixin, nn.Linear(self.embed_dim, self.params.lat_dim), nn.GELU(), nn.Dropout(self.params.dropout), - nn.Linear(self.params.lat_dim, 1), - nn.Sigmoid() + nn.Linear(self.params.lat_dim, 10), + nn.Softmax() ) def forward(self, x, mask=None): diff --git 
a/models/transformer_model_horizontal.py b/models/transformer_model_horizontal.py index 701ccb4..0c5ecc0 100644 --- a/models/transformer_model_horizontal.py +++ b/models/transformer_model_horizontal.py @@ -8,11 +8,12 @@ from torch import nn from ml_lib.modules.blocks import TransformerModule from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, - BaseDataloadersMixin, BaseTestMixin) + BaseDataloadersMixin, BaseTestMixin, BaseLossMixin) MIN_NUM_PATCHES = 16 class HorizontalVisualTransformer(DatasetMixin, + BaseLossMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, @@ -35,6 +36,7 @@ class HorizontalVisualTransformer(DatasetMixin, # Model Paramters # ============================================================================= # Additional parameters + self.n_classes = self.dataset.train_dataset.n_classes self.embed_dim = self.params.embedding_size self.patch_size = self.params.patch_size self.height = height @@ -81,8 +83,8 @@ class HorizontalVisualTransformer(DatasetMixin, nn.Linear(self.embed_dim, self.params.lat_dim), nn.GELU(), nn.Dropout(self.params.dropout), - nn.Linear(self.params.lat_dim, 1), - nn.Sigmoid() + nn.Linear(self.params.lat_dim, 10), + nn.Softmax() ) def forward(self, x, mask=None): diff --git a/models/transformer_model_vertical.py b/models/transformer_model_vertical.py index 914baf4..44351c0 100644 --- a/models/transformer_model_vertical.py +++ b/models/transformer_model_vertical.py @@ -8,11 +8,12 @@ from torch import nn from ml_lib.modules.blocks import TransformerModule from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin, - BaseDataloadersMixin, BaseTestMixin) + BaseDataloadersMixin, BaseTestMixin, BaseLossMixin) MIN_NUM_PATCHES = 16 class VerticalVisualTransformer(DatasetMixin, + BaseLossMixin, BaseDataloadersMixin, BaseTrainMixin, BaseValMixin, @@ -80,8 +81,8 @@ class VerticalVisualTransformer(DatasetMixin, nn.Linear(self.embed_dim, self.params.lat_dim), nn.GELU(), nn.Dropout(self.params.dropout), - nn.Linear(self.params.lat_dim, 1), - nn.Sigmoid() + nn.Linear(self.params.lat_dim, 10), + nn.Softmax() ) def forward(self, x, mask=None): diff --git a/multi_run.py b/multi_run.py index 7777d1e..882b2bd 100644 --- a/multi_run.py +++ b/multi_run.py @@ -14,64 +14,134 @@ warnings.filterwarnings('ignore', category=UserWarning) if __name__ == '__main__': - args = main_arg_parser.parse_args() - # Model Settings - config = Config().read_namespace(args) + if False: + args = main_arg_parser.parse_args() + # Model Settings + config = Config().read_namespace(args) - arg_dict = dict() - for seed in range(1): - arg_dict.update(main_seed=seed) - if False: - for patch_size in [3, 5 , 9]: - for model in ['VerticalVisualTransformer']: - arg_dict.update(model_type=model, model_patch_size=patch_size) - raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0, - data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0, - data_stretch=False, train_epochs=401) + arg_dict = dict() + for seed in range(1): + arg_dict.update(main_seed=seed) + if False: + for patch_size in [3, 5 , 9]: + for model in ['VerticalVisualTransformer']: + arg_dict.update(model_type=model, model_patch_size=patch_size) + raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0, + 
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0, + data_stretch=False, train_epochs=401) - all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7, - data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4, - data_stretch=True, train_epochs=101) + all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7, + data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4, + data_stretch=True, train_epochs=101) - speed_conf = raw_conf.copy() - speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7, - data_stretch=True, train_epochs=101) + speed_conf = raw_conf.copy() + speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7, + data_stretch=True, train_epochs=101) - mask_conf = raw_conf.copy() - mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101) + mask_conf = raw_conf.copy() + mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101) - noise_conf = raw_conf.copy() - noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101) + noise_conf = raw_conf.copy() + noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101) - shift_conf = raw_conf.copy() - shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101) + shift_conf = raw_conf.copy() + shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101) - loudness_conf = raw_conf.copy() - loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101) + loudness_conf = raw_conf.copy() + loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101) - for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]: + for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]: + + arg_dict.update(dicts) + if True: + for patch_size in [7]: + for lat_dim in [32]: + for heads in [8]: + for embedding_size in [7**2]: + for attn_depth in [1, 3, 5, 7]: + for model in ['HorizontalVisualTransformer']: + arg_dict.update( + model_type=model, + model_patch_size=patch_size, + model_lat_dim=lat_dim, + model_heads=heads, + model_embedding_size=embedding_size, + model_attn_depth=attn_depth + ) + config = config.update(arg_dict) + version_path = config.exp_path / config.version + if version_path.exists(): + if not (version_path / 'weights.ckpt').exists(): + shutil.rmtree(version_path) + else: + continue + run_lightning_loop(config) + + import matplotlib + import matplotlib.pyplot as plt + import matplotlib.cm as cm + import numpy as np + + from diffractio import degrees, mm, plt, sp, um, np + from diffractio.scalar_fields_XY import Scalar_field_XY + from diffractio.utils_drawing import draw_several_fields + from diffractio.scalar_masks_XY import Scalar_mask_XY + from diffractio.scalar_sources_XY import Scalar_source_XY + + from matplotlib import rcParams + + rcParams['figure.figsize']=(7,5) + rcParams['figure.dpi']=75 + + period = 20 * um + num_pixels = 512 + + length = 250 * um + x0 = np.linspace(-length / 2, length / 2, num_pixels) + y0 = np.linspace(-length / 2, length / 2, num_pixels) + wavelength = 0.6238 * um + + u1 = Scalar_source_XY(x=x0, y=y0, wavelength=wavelength) + u1.plane_wave(A=1, theta=0 * degrees, phi=0 * degrees) + + t1 = Scalar_mask_XY(x=x0, y=y0, wavelength=wavelength) + t1.forked_grating(kind='amplitude', + r0=(0 * um, 0 * um), period=period, 
l=3, alpha=2, angle=0 * degrees) + + u2 = u1 * t1 + + t2 = Scalar_mask_XY(x=x0, y=y0, wavelength=wavelength) + t2.roughness(t=(20 * um, 20 * um), s=1 * um) + + u2 = u2 * t2 + u2.draw(kind='phase') + + u3 = u2.RS(z=1 * mm, new_field=True) + + u4 = u2.RS(z=5 * mm, new_field=True) + + u5 = u2.RS(z=10 * mm, new_field=True) + + print('draw') + + draw_several_fields((u3, u4, u5), titulos=('1 mm', '5 mm', '10 mm'), logarithm=True) + + plt.show() + + pass + + u2 = t2 * u1 + u2.draw(kind='phase') + + u3 = u2.RS(z=1 * mm, new_field=True) + + u4 = u2.RS(z=5 * mm, new_field=True) + + u5 = u2.RS(z=10 * mm, new_field=True) + + print('draw') + + draw_several_fields((u3, u4, u5), titulos=('1 mm', '5 mm', '10 mm'), logarithm=True) + + plt.show() - arg_dict.update(dicts) - if True: - for patch_size in [7]: - for lat_dim in [32]: - for heads in [8]: - for embedding_size in [7**2]: - for attn_depth in [1, 3, 5, 7]: - for model in ['HorizontalVisualTransformer']: - arg_dict.update( - model_type=model, - model_patch_size=patch_size, - model_lat_dim=lat_dim, - model_heads=heads, - model_embedding_size=embedding_size, - model_attn_depth=attn_depth - ) - config = config.update(arg_dict) - version_path = config.exp_path / config.version - if version_path.exists(): - if not (version_path / 'weights.ckpt').exists(): - shutil.rmtree(version_path) - else: - continue - run_lightning_loop(config) diff --git a/util/module_mixins.py b/util/module_mixins.py index be2a7b8..924ef4e 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -1,11 +1,17 @@ -from collections import defaultdict - +# Imports from python Internals from abc import ABC -from argparse import Namespace +from itertools import cycle +from collections import defaultdict, namedtuple -import sklearn -import torch +# Numerical Imports, Metrics and Plotting import numpy as np +from sklearn.ensemble import IsolationForest +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_auc_score, roc_curve, auc, f1_score, \ + recall_score, average_precision_score +from matplotlib import pyplot as plt + +# Import Deep Learning Framework +import torch from torch import nn from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts @@ -13,15 +19,25 @@ from torch.utils.data import DataLoader from torchcontrib.optim import SWA from torchvision.transforms import Compose, RandomApply -from ml_lib.audio_toolset.audio_augmentation import Speed +# Import Functions and Modules from MLLIB from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug -from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal +from ml_lib.audio_toolset.audio_io import NormalizeLocal from ml_lib.modules.util import LightningBaseModule +from ml_lib.utils.tools import to_one_hot from ml_lib.utils.transforms import ToTensor +# Import Project Variables import variables as V +class BaseLossMixin: + + absolute_loss = nn.L1Loss() + nll_loss = nn.NLLLoss() + bce_loss = nn.BCELoss() + ce_loss = nn.CrossEntropyLoss() + + class BaseOptimizerMixin: def configure_optimizers(self): @@ -60,16 +76,12 @@ class BaseOptimizerMixin: class BaseTrainMixin: - absolute_loss = nn.L1Loss() - nll_loss = nn.NLLLoss() - bce_loss = nn.BCELoss() - def training_step(self, batch_xy, batch_nb, *args, **kwargs): assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - bce_loss = self.bce_loss(y.squeeze(), batch_y) - return dict(loss=bce_loss) + loss = 
self.ce_loss(y.squeeze(), batch_y.long()) + return dict(loss=loss) def training_epoch_end(self, outputs): assert isinstance(self, LightningBaseModule) @@ -84,55 +96,39 @@ class BaseTrainMixin: class BaseValMixin: - absolute_loss = nn.L1Loss() - nll_loss = nn.NLLLoss() - bce_loss = nn.BCELoss() - - def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs): + def validation_step(self, batch_xy, batch_idx, *args, **kwargs): assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - val_bce_loss = self.bce_loss(y.squeeze(), batch_y) - return dict(val_bce_loss=val_bce_loss, + val_loss = self.ce_loss(y.squeeze(), batch_y.long()) + return dict(val_loss=val_loss, batch_idx=batch_idx, y=y, batch_y=batch_y) def validation_epoch_end(self, outputs, *_, **__): assert isinstance(self, LightningBaseModule) summary_dict = dict() - for output_idx, output in enumerate(outputs): - keys = list(output[0].keys()) - ident = '' if output_idx == 0 else '_train' - summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key] - for output in output])) - for key in keys if 'loss' in key} - ) - # UnweightedAverageRecall - y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy() - y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy() + keys = list(outputs[0].keys()) + summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key] + for output in outputs])) + for key in keys if 'loss' in key} + ) - y_pred = (y_pred >= 0.5).astype(np.float32) + additional_scores = self.additional_scores(outputs) + summary_dict.update(**additional_scores) - uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', - sample_weight=None, zero_division='warn') - uar_score = torch.as_tensor(uar_score) - summary_dict.update({f'uar{ident}_score': uar_score}) - for key in summary_dict.keys(): - self.log(key, summary_dict[key]) + for key in summary_dict.keys(): + self.log(key, summary_dict[key]) class BaseTestMixin: - absolute_loss = nn.L1Loss() - nll_loss = nn.NLLLoss() - bce_loss = nn.BCELoss() - - def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs): + def test_step(self, batch_xy, batch_idx, *_, **__): assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - test_bce_loss = self.bce_loss(y.squeeze(), batch_y) - return dict(test_bce_loss=test_bce_loss, + test_loss = self.ce_loss(y.squeeze(), batch_y.long()) + return dict(test_loss=test_loss, batch_idx=batch_idx, y=y, batch_y=batch_y) def test_epoch_end(self, outputs, *_, **__): @@ -145,16 +141,9 @@ class BaseTestMixin: for key in keys if 'loss' in key} ) - # UnweightedAverageRecall - y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() - y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() + additional_scores = self.additional_scores(outputs) + summary_dict.update(**additional_scores) - y_pred = (y_pred >= 0.5).astype(np.float32) - - uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', - sample_weight=None, zero_division='warn') - uar_score = torch.as_tensor(uar_score) - summary_dict.update({f'uar_score': uar_score}) for key in summary_dict.keys(): self.log(key, summary_dict[key]) @@ -167,53 +156,56 @@ class DatasetMixin: # Dataset # ============================================================================= # Mel Transforms - mel_transforms = Compose([ - # Audio to Mel 
Transformations - AudioToMel(sr=self.params.sr, - n_mels=self.params.n_mels, - n_fft=self.params.n_fft, - hop_length=self.params.hop_length), - MelToImage()]) - - mel_transforms_train = Compose([ - # Audio to Mel Transformations - Speed(max_amount=self.params.speed_amount, - speed_min=self.params.speed_min, - speed_max=self.params.speed_max - ), - mel_transforms]) + mel_kwargs = dict(sample_rate=self.params.sr, + n_mels=self.params.n_mels, + n_fft=self.params.n_fft, + hop_length=self.params.hop_length) # Utility - util_transforms = Compose([NormalizeLocal(), ToTensor()]) + normalize = NormalizeLocal() # Data Augmentations - aug_transforms = Compose([ + mel_augmentations = Compose([ RandomApply([ - NoiseInjection(self.params.noise_ratio), - LoudnessManipulator(self.params.loudness_ratio), - ShiftTime(self.params.shift_ratio), - MaskAug(self.params.mask_ratio), + NoiseInjection(0.2), + LoudnessManipulator(0.5), + ShiftTime(0.4), + MaskAug(0.2), ], p=0.6), - util_transforms]) + normalize]) # Datasets - dataset = Namespace( - **dict( - # TRAIN DATASET - train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train, - use_preprocessed=self.params.use_preprocessed, - stretch_dataset=self.params.stretch, - mel_transforms=mel_transforms_train, transforms=aug_transforms), - # VALIDATION DATASET - val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train, - mel_transforms=mel_transforms, transforms=util_transforms), - val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel, - mel_transforms=mel_transforms, transforms=util_transforms), - # TEST DATASET - test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test, - mel_transforms=mel_transforms, transforms=util_transforms), - ) - ) + Dataset = namedtuple('Datasets', 'train_dataset val_dataset test_dataset') + dataset = Dataset(self.dataset_class(data_root=self.params.root, # TRAIN DATASET + setting=V.DATA_OPTION_train, + fold=list(range(1,8)), + reset=self.params.reset, + mel_kwargs=mel_kwargs, + mel_augmentations=mel_augmentations), + val_dataset=self.dataset_class(data_root=self.params.root, # VALIDATION DATASET + setting=V.DATA_OPTION_devel, + fold=9, + reset=self.params.reset, + mel_kwargs=mel_kwargs, + mel_augmentations=normalize), + test_dataset=self.dataset_class(data_root=self.params.root, # TEST DATASET + setting=V.DATA_OPTION_test, + fold=10, + reset=self.params.reset, + mel_kwargs=mel_kwargs, + mel_augmentations=normalize), + ) + + if dataset.train_dataset.task_type == V.TASK_OPTION_binary: + # noinspection PyAttributeOutsideInit + self.additional_scores = BinaryScores(self) + + elif dataset.train_dataset.task_type == V.TASK_OPTION_multiclass: + # noinspection PyAttributeOutsideInit + self.additional_scores = MultiClassScores(self) + else: + raise ValueError + return dataset @@ -240,10 +232,185 @@ class BaseDataloadersMixin(ABC): # Validation Dataloader def val_dataloader(self): assert isinstance(self, LightningBaseModule) - val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True, - batch_size=self.params.batch_size, num_workers=self.params.worker) + return DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True, + batch_size=self.params.batch_size, num_workers=self.params.worker) - train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker, - pin_memory=True, - batch_size=self.params.batch_size, shuffle=False) - return [val_dataloader, train_dataloader] + +class 
BaseScores(ABC): + + def __init__(self, lightning_model): + self.model = lightning_model + pass + + def __call__(self, outputs): + # summary_dict = dict() + # return summary_dict + raise NotImplementedError + + +class MultiClassScores(BaseScores): + + def __init__(self, *args): + super(MultiClassScores, self).__init__(*args) + pass + + def __call__(self, outputs): + summary_dict = dict() + ####################################################################################### + # Additional Score - UAR - ROC - Conf. Matrix - F1 + ####################################################################################### + # + # INIT + y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() + y_true_one_hot = to_one_hot(y_true, self.model.n_classes) + + y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy() + y_pred_max = np.argmax(y_pred, axis=1) + + class_names = {val: key for key, val in self.model.dataset.test_dataset.classes.items()} + ###################################################################################### + # + # F1 SCORE + micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro', sample_weight=None, + zero_division=True) + macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None, + zero_division=True) + summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score)) + + ####################################################################################### + # + # ROC Curve + + # Compute ROC curve and ROC area for each class + fpr = dict() + tpr = dict() + roc_auc = dict() + for i in range(self.model.n_classes): + fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i]) + roc_auc[i] = auc(fpr[i], tpr[i]) + + # Compute micro-average ROC curve and ROC area + fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel()) + roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) + + # First aggregate all false positive rates + all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.n_classes)])) + + # Then interpolate all ROC curves at this points + mean_tpr = np.zeros_like(all_fpr) + for i in range(self.model.n_classes): + mean_tpr += np.interp(all_fpr, fpr[i], tpr[i]) + + # Finally average it and compute AUC + mean_tpr /= self.model.n_classes + + fpr["macro"] = all_fpr + tpr["macro"] = mean_tpr + roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) + + # Plot all ROC curves + plt.figure() + plt.plot(fpr["micro"], tpr["micro"], + label=f'micro ROC ({round(roc_auc["micro"], 2)})', + color='deeppink', linestyle=':', linewidth=4) + + plt.plot(fpr["macro"], tpr["macro"], + label=f'macro ROC({round(roc_auc["macro"], 2)})', + color='navy', linestyle=':', linewidth=4) + + colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua', + 'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], ) + + for i, color in zip(range(self.model.n_classes), colors): + plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})') + + plt.plot([0, 1], [0, 1], 'k--', lw=2) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.legend(loc="lower right") + + self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch) + self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch, ext='pdf') + plt.clf() + + 
####################################################################################### + # + # ROC SCORE + + try: + macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", + average="macro") + summary_dict.update(macro_roc_auc_ovr=macro_roc_auc_ovr) + except ValueError: + micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", + average="micro") + summary_dict.update(micro_roc_auc_ovr=micro_roc_auc_ovr) + + ####################################################################################### + # + # Confusion matrix + + cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max], + labels=[class_names[key] for key in class_names.keys()], + normalize='all') + disp = ConfusionMatrixDisplay(confusion_matrix=cm, + display_labels=[class_names[i] for i in range(self.model.n_classes)] + ) + disp.plot(include_values=True) + + self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch) + self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf') + + plt.close('all') + return summary_dict + + +class BinaryScores(BaseScores): + + def __init__(self, *args): + super(BinaryScores, self).__init__(*args) + + def __call__(self, outputs): + summary_dict = dict() + + # Additional Score like the unweighted Average Recall: + ######################### + # UnweightedAverageRecall + y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() + y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy() + + # How to apply a threshold manualy + # y_pred = (y_pred >= 0.5).astype(np.float32) + + # How to apply a threshold by IF (Isolation Forest) + clf = IsolationForest(random_state=self.model.seed) + y_score = clf.fit_predict(y_pred.reshape(-1,1)) + y_score = (np.asarray(y_score) == -1).astype(np.float32) + + uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro', + sample_weight=None, zero_division='warn') + summary_dict.update(dict(uar_score=uar_score)) + ######################### + # Precission + precision_score = average_precision_score(y_true, y_score) + summary_dict.update(dict(precision_score=precision_score)) + + ######################### + # AUC + try: + auc_score = roc_auc_score(y_true=y_true, y_score=y_score) + summary_dict.update(dict(auc_score=auc_score)) + except ValueError: + summary_dict.update(dict(auc_score=-1)) + + ######################### + # pAUC + try: + pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15) + summary_dict.update(dict(pauc_score=pauc)) + except ValueError: + summary_dict.update(dict(pauc_score=-1)) + + return summary_dict diff --git a/variables.py b/variables.py index 9995fb9..3a7ee9c 100644 --- a/variables.py +++ b/variables.py @@ -4,8 +4,22 @@ from argparse import Namespace CLEAR = 0 MASK = 1 -NUM_CLASSES = 2 +# Task Options +TASK_OPTION_multiclass = 'multiclass' +N_CLASS_multi = 10 +multi_classes_names = ['air_conditioner', 'car_horn', 'children_playing', + 'dog_bar', 'drilling', 'engine_idling', + 'gun_shot', 'jackhammer', 'siren', 'street_music'] +multi_classes = {key: val for val, key in enumerate(multi_classes_names)} +TASK_OPTION_binary = 'binary' +N_CLASS_binary = 2 +binary_CLASS_clear = 0 +binary_CLASS_maske = 1 + # Dataset Options -DATA_OPTIONS = Namespace(test='test', devel='devel', train='train') +DATA_OPTION_test = 'test' +DATA_OPTION_devel = 'devel' +DATA_OPTION_train = 'train' +DATA_OPTIONS = 
[DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
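
Editor's note: for reference, a minimal usage sketch of how the refactored pieces fit together, i.e. the new DATA_OPTION_* constants from variables.py, the --data_reset flag from _paramters.py, and the multiclass Urban8K dataset. This is an editorial illustration, not code from this patch; sr, n_mels and hop_length mirror the defaults shown in _paramters.py, n_fft is an assumed value, and the fold split follows the DatasetMixin above.

    import multiprocessing as mp

    import variables as V
    from datasets.urban_8k import Urban8K

    # Mel parameters; n_fft is illustrative, the rest mirror _paramters.py defaults.
    mel_kwargs = dict(sample_rate=16000, n_mels=64, n_fft=512, hop_length=256)

    train_set = Urban8K(data_root='data',
                        setting=V.DATA_OPTION_train,
                        fold=list(range(1, 8)),        # folds 1-7, as in DatasetMixin
                        num_worker=mp.cpu_count(),
                        reset=False,                   # corresponds to --data_reset
                        mel_kwargs=mel_kwargs)

    print(train_set.n_classes)     # 10 == V.N_CLASS_multi for the multiclass task
    mel, label = train_set[0]      # mel segment and class label (as int tensor)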