torchaudio testing

Si11ium 2020-12-17 08:02:29 +01:00
parent 95dcf22f3d
commit 68431b848e
13 changed files with 578 additions and 418 deletions

View File

@@ -22,7 +22,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="")
main_arg_parser.add_argument("--data_worker", type=int, default=6, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_reset", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")

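A note on the `type=strtobool` pattern above: `distutils.util.strtobool` maps common truthy/falsy strings to 1/0, so boolean flags can be passed as plain strings on the command line. A minimal sketch (the parser and flag name here are illustrative):

from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
# strtobool('false') -> 0, strtobool('true') -> 1; anything else raises ValueError
parser.add_argument("--data_reset", type=strtobool, default=False)
args = parser.parse_args(["--data_reset", "false"])
assert args.data_reset == 0
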
datasets/base_dataset.py (new file, 104 lines)
View File

@@ -0,0 +1,104 @@
import pickle
from pathlib import Path
from typing import Union
from abc import ABC
import variables as V
from torch.utils.data import Dataset
class BaseAudioToMelDataset(Dataset, ABC):
@property
def task_type(self):
return self._task_type
@property
def classes(self):
return V.multi_classes
@property
def n_classes(self):
return V.N_CLASS_binary if self.task_type == V.TASK_OPTION_binary else V.N_CLASS_multi
@property
def sample_shape(self):
return self[0][0].shape
@property
def _fingerprint(self):
raise NotImplementedError
# Data Structures
@property
def mel_folder(self):
return self.data_root / 'mel'
@property
def wav_folder(self):
return self.data_root / self._wav_folder_name
@property
def _container_ext(self):
return '.mel'
def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs,
mel_augmentations=None, audio_augmentations=None, reset=False,
wav_folder_name='wav', **_):
super(BaseAudioToMelDataset, self).__init__()
# Dataset Parameters
self.data_root = Path(data_root)
self._wav_folder_name = wav_folder_name
self.reset = reset
self.mel_kwargs = mel_kwargs
# Transformations
self.mel_augmentations = mel_augmentations
self.audio_augmentations = audio_augmentations
self._task_type = task_type
# Find all raw files and turn generator to persistent list:
self._wav_files = list(self.wav_folder.rglob('*.wav'))
# Build the Dataset
self._dataset = self._build_dataset()
def __len__(self):
raise NotImplementedError
def __getitem__(self, item):
raise NotImplementedError
def _build_dataset(self):
raise NotImplementedError
def _check_reset_and_clean_up(self, reset):
all_mel_folders = set([str(x.parent).replace(self._wav_folder_name, 'mel') for x in self._wav_files])
for mel_folder in all_mel_folders:
param_storage = Path(mel_folder) / 'data_params.pik'
param_storage.parent.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
this_reset = reset
else:
print('Diverging parameters were found; refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
this_reset = True
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
this_reset = True
if this_reset:
all_mel_files = self.mel_folder.rglob(f'*{self._container_ext}')
for mel_file in all_mel_files:
mel_file.unlink()

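To make the contract of BaseAudioToMelDataset concrete, here is a hedged sketch of a minimal subclass; the class name and the hook bodies are illustrative assumptions, not code from this commit:

from datasets.base_dataset import BaseAudioToMelDataset

class ToyMelDataset(BaseAudioToMelDataset):  # hypothetical example subclass
    @property
    def _fingerprint(self):
        # a stable string identifying the preprocessing parameters,
        # so _check_reset_and_clean_up can detect stale mel caches
        return str(sorted((self.mel_kwargs or {}).items()))

    def _build_dataset(self):
        # one entry per wav file collected by the base __init__
        return list(self._wav_files)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, item):
        # a real subclass would return (mel_tensor, label) here
        return self._dataset[item]
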
View File

@@ -20,17 +20,20 @@ class BinaryMasksDataset(Dataset):
@property
def _fingerprint(self):
return dict(**self._mel_kwargs, normalize=self.normalize)
return dict(**self._mel_kwargs if self._mel_kwargs else dict())
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
use_preprocessed=True, mel_kwargs=None):
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
self.task = V.TASK_OPTION_binary
self.data_root = Path(data_root) / 'ComParE2020_Mask'
self.setting = setting
self._mel_kwargs = mel_kwargs
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.container_ext = '.pik'

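The new dict-based `_fingerprint` makes the cached-parameter check concrete: the mel kwargs are pickled next to the cache, and any mismatch forces a rebuild. A hedged standalone sketch of that round-trip (the helper name is hypothetical):

import pickle
from pathlib import Path

def cache_is_current(param_file: Path, fingerprint: dict) -> bool:
    # True only if a stored fingerprint exists and matches the current one
    try:
        return pickle.loads(param_file.read_bytes()) == fingerprint
    except FileNotFoundError:
        param_file.write_bytes(pickle.dumps(fingerprint))
        return False
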
View File

@@ -1,140 +1,78 @@
import pickle
from pathlib import Path
import multiprocessing as mp
from typing import Union, List
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import multiprocessing as mp
from torch.utils.data import ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
from datasets.base_dataset import BaseAudioToMelDataset
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset
class Urban8K(Dataset):
try:
torch.multiprocessing.set_sharing_strategy('file_system')
except AttributeError:
pass
@property
def sample_shape(self):
return self[0][0].shape
class Urban8K(BaseAudioToMelDataset):
@property
def _fingerprint(self):
return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
**_):
def __init__(self,
data_root, setting, fold: Union[int, List]=1, num_worker=mp.cpu_count(),
reset=False, sample_segment_len=50, sample_hop_len=20,
**kwargs):
self.num_worker = num_worker
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert fold in range(1, 11)
super(Urban8K, self).__init__()
assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold])
self.data_root = Path(data_root) / 'UrbanSound8K'
# Dataset Parameters
self.setting = setting
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
fold = fold if self.setting != V.DATA_OPTION_test else 10
self.fold = fold if isinstance(fold, list) else [fold]
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
transforms = transforms or F_x(in_shape=None)
self.sample_segment_len = sample_segment_len
self.sample_hop_len = sample_hop_len
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
# Dataset specific super init
super(Urban8K, self).__init__(Path(data_root) / 'UrbanSound8K',
V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs
)
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
def _build_subdataset(self, row):
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
fold, class_id = (int(x) for x in (fold, class_id))
if fold in self.fold:
audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__)
else:
return None
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
def _build_dataset(self):
dataset = list()
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
all_rows = list(f)
# guard against a chunksize of 0 when there are fewer rows than workers
chunksize = max(1, len(all_rows) // max(self.num_worker, 1))
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(all_rows)) as pbar:
for i, sub_dataset in enumerate(
pool.imap_unordered(self._build_subdataset, all_rows, chunksize=chunksize)):
pbar.update()
dataset.append(sub_dataset)
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
dataset = ConcatDataset([x for x in dataset if x is not None])
return dataset
def __len__(self):
return len(self._dataset)
def _pre_process(self):
print('Preprocessing Mel Files....')
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(self._labels)) as pbar:
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
pbar.update()
def _build_mel(self, filename):
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
if not mel_file:
raw_sample, sr = librosa.core.load(wav_file)
mel_sample = self._mel_transform(raw_sample)
m, n = mel_sample.shape
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
self._mel_folder.mkdir(exist_ok=True, parents=True)
with mel_file.open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
# print(f"Already existed.. Skipping {filename}")
mel_file = mel_file[0]
with mel_file.open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
label = torch.as_tensor(label, dtype=torch.float)
label = torch.as_tensor(label, dtype=torch.int)
return transformed_samples, label

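Putting the refactored Urban8K signature together, a hedged usage sketch (the module path and mel parameter values are assumptions; keyword names follow the __init__ above):

import variables as V
from datasets.urban_8k import Urban8K  # module path assumed

train_set = Urban8K(data_root='data',
                    setting=V.DATA_OPTION_train,
                    fold=list(range(1, 9)),  # folds 1-8 for training
                    reset=False,
                    mel_kwargs=dict(sample_rate=16000, n_fft=512,
                                    hop_length=256, n_mels=64))
print(len(train_set), train_set.sample_shape)
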
View File

@@ -1,140 +0,0 @@
import pickle
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class Urban8K_TO(Dataset):
@property
def sample_shape(self):
return self[0][0].shape
@property
def _fingerprint(self):
return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
**_):
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert fold in range(1, 11)
super(Urban8K_TO, self).__init__()
self.data_root = Path(data_root) / 'UrbanSound8K'
self.setting = setting
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
def __len__(self):
return len(self._dataset)
def _pre_process(self):
print('Preprocessing Mel Files....')
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(self._labels)) as pbar:
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
pbar.update()
def _build_mel(self, filename):
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
if not mel_file:
raw_sample, sr = librosa.core.load(wav_file)
mel_sample = self._mel_transform(raw_sample)
m, n = mel_sample.shape
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
self._mel_folder.mkdir(exist_ok=True, parents=True)
with mel_file.open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
# print(f"Already existed.. Skipping {filename}")
mel_file = mel_file[0]
with mel_file.open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label

View File

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, NormalizeLocal, MelToImage
# Dataset and Dataloaders
# =============================================================================
@@ -28,8 +28,8 @@ from datasets.binar_masks import BinaryMasksDataset
def prepare_dataloader(config_obj):
mel_transforms = Compose([
AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length),
LibrosaAudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length),
MelToImage()])
transforms = Compose([NormalizeLocal(), ToTensor()])
"""

View File

@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, NormalizeLocal, MelToImage
# Dataset and Dataloaders
# =============================================================================
@@ -26,8 +26,8 @@ from datasets.binar_masks import BinaryMasksDataset
def prepare_dataloader(config_obj):
mel_transforms = Compose([
AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length),
LibrosaAudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length),
MelToImage()])
transforms = Compose([NormalizeLocal(), ToTensor()])
aug_transforms = Compose([

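Given the commit title, the rename from AudioToMel to LibrosaAudioToMel presumably makes room for a torchaudio-backed sibling (PyTorchAudioToMelDataset is already imported in the dataset code above). A hedged sketch of such a transform using torchaudio's public API; the wrapper class itself is an assumption:

import torch
import torchaudio

class TorchAudioToMel:  # hypothetical torchaudio counterpart, not from this commit
    def __init__(self, sr=16000, n_mels=64, n_fft=512, hop_length=256):
        self.transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # waveform: (channels, time) -> mel: (channels, n_mels, frames)
        return self.transform(waveform)
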
View File

@@ -10,11 +10,12 @@ from einops import rearrange, repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class VisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -84,8 +85,8 @@ class VisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):

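One caveat on the new classifier heads (the same change appears in the horizontal and vertical variants below): the reworked mixins train with nn.CrossEntropyLoss, which applies log-softmax internally and expects raw logits, so an explicit nn.Softmax() (here also missing its dim argument) normalizes twice. The conventional pairing, for comparison:

import torch
from torch import nn

head = nn.Linear(32, 10)           # emit raw logits; no Softmax layer
criterion = nn.CrossEntropyLoss()  # log-softmax happens inside the loss

x, target = torch.randn(4, 32), torch.randint(0, 10, (4,))
loss = criterion(head(x), target)
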
View File

@@ -8,11 +8,12 @@ from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class HorizontalVisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -35,6 +36,7 @@ class HorizontalVisualTransformer(DatasetMixin,
# Model Parameters
# =============================================================================
# Additional parameters
self.n_classes = self.dataset.train_dataset.n_classes
self.embed_dim = self.params.embedding_size
self.patch_size = self.params.patch_size
self.height = height
@@ -81,8 +83,8 @@ class HorizontalVisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):

View File

@@ -8,11 +8,12 @@ from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
BaseDataloadersMixin, BaseTestMixin, BaseLossMixin)
MIN_NUM_PATCHES = 16
class VerticalVisualTransformer(DatasetMixin,
BaseLossMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
@@ -80,8 +81,8 @@ class VerticalVisualTransformer(DatasetMixin,
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
nn.Linear(self.params.lat_dim, 10),
nn.Softmax()
)
def forward(self, x, mask=None):

View File

@@ -14,64 +14,134 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__':
args = main_arg_parser.parse_args()
# Model Settings
config = Config().read_namespace(args)
if False:
args = main_arg_parser.parse_args()
# Model Settings
config = Config().read_namespace(args)
arg_dict = dict()
for seed in range(1):
arg_dict.update(main_seed=seed)
if False:
for patch_size in [3, 5, 9]:
for model in ['VerticalVisualTransformer']:
arg_dict.update(model_type=model, model_patch_size=patch_size)
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
data_stretch=False, train_epochs=401)
arg_dict = dict()
for seed in range(1):
arg_dict.update(main_seed=seed)
if False:
for patch_size in [3, 5, 9]:
for model in ['VerticalVisualTransformer']:
arg_dict.update(model_type=model, model_patch_size=patch_size)
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
data_stretch=False, train_epochs=401)
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
data_stretch=True, train_epochs=101)
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
data_stretch=True, train_epochs=101)
speed_conf = raw_conf.copy()
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
data_stretch=True, train_epochs=101)
speed_conf = raw_conf.copy()
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
data_stretch=True, train_epochs=101)
mask_conf = raw_conf.copy()
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
mask_conf = raw_conf.copy()
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
noise_conf = raw_conf.copy()
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
noise_conf = raw_conf.copy()
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
shift_conf = raw_conf.copy()
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
shift_conf = raw_conf.copy()
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
loudness_conf = raw_conf.copy()
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
loudness_conf = raw_conf.copy()
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
arg_dict.update(dicts)
if True:
for patch_size in [7]:
for lat_dim in [32]:
for heads in [8]:
for embedding_size in [7**2]:
for attn_depth in [1, 3, 5, 7]:
for model in ['HorizontalVisualTransformer']:
arg_dict.update(
model_type=model,
model_patch_size=patch_size,
model_lat_dim=lat_dim,
model_heads=heads,
model_embedding_size=embedding_size,
model_attn_depth=attn_depth
)
config = config.update(arg_dict)
version_path = config.exp_path / config.version
if version_path.exists():
if not (version_path / 'weights.ckpt').exists():
shutil.rmtree(version_path)
else:
continue
run_lightning_loop(config)
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from diffractio import degrees, mm, plt, sp, um, np
from diffractio.scalar_fields_XY import Scalar_field_XY
from diffractio.utils_drawing import draw_several_fields
from diffractio.scalar_masks_XY import Scalar_mask_XY
from diffractio.scalar_sources_XY import Scalar_source_XY
from matplotlib import rcParams
rcParams['figure.figsize'] = (7, 5)
rcParams['figure.dpi'] = 75
period = 20 * um
num_pixels = 512
length = 250 * um
x0 = np.linspace(-length / 2, length / 2, num_pixels)
y0 = np.linspace(-length / 2, length / 2, num_pixels)
wavelength = 0.6238 * um
u1 = Scalar_source_XY(x=x0, y=y0, wavelength=wavelength)
u1.plane_wave(A=1, theta=0 * degrees, phi=0 * degrees)
t1 = Scalar_mask_XY(x=x0, y=y0, wavelength=wavelength)
t1.forked_grating(kind='amplitude',
r0=(0 * um, 0 * um), period=period, l=3, alpha=2, angle=0 * degrees)
u2 = u1 * t1
t2 = Scalar_mask_XY(x=x0, y=y0, wavelength=wavelength)
t2.roughness(t=(20 * um, 20 * um), s=1 * um)
u2 = u2 * t2
u2.draw(kind='phase')
u3 = u2.RS(z=1 * mm, new_field=True)
u4 = u2.RS(z=5 * mm, new_field=True)
u5 = u2.RS(z=10 * mm, new_field=True)
print('draw')
draw_several_fields((u3, u4, u5), titulos=('1 mm', '5 mm', '10 mm'), logarithm=True)
plt.show()
pass
u2 = t2 * u1
u2.draw(kind='phase')
u3 = u2.RS(z=1 * mm, new_field=True)
u4 = u2.RS(z=5 * mm, new_field=True)
u5 = u2.RS(z=10 * mm, new_field=True)
print('draw')
draw_several_fields((u3, u4, u5), titulos=('1 mm', '5 mm', '10 mm'), logarithm=True)
plt.show()
arg_dict.update(dicts)
if True:
for patch_size in [7]:
for lat_dim in [32]:
for heads in [8]:
for embedding_size in [7**2]:
for attn_depth in [1, 3, 5, 7]:
for model in ['HorizontalVisualTransformer']:
arg_dict.update(
model_type=model,
model_patch_size=patch_size,
model_lat_dim=lat_dim,
model_heads=heads,
model_embedding_size=embedding_size,
model_attn_depth=attn_depth
)
config = config.update(arg_dict)
version_path = config.exp_path / config.version
if version_path.exists():
if not (version_path / 'weights.ckpt').exists():
shutil.rmtree(version_path)
else:
continue
run_lightning_loop(config)

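The nested hyper-parameter loops above can be flattened with itertools.product; this sketch mirrors the loop variables (values are the ones from the diff, and the print stands in for the config.update / run_lightning_loop calls):

from itertools import product

grid = dict(model_type=['HorizontalVisualTransformer'],
            model_patch_size=[7], model_lat_dim=[32], model_heads=[8],
            model_embedding_size=[7 ** 2], model_attn_depth=[1, 3, 5, 7])
for values in product(*grid.values()):
    arg_dict = dict(zip(grid.keys(), values))
    print(arg_dict)  # a real run would update the config and launch here
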
View File

@@ -1,11 +1,17 @@
from collections import defaultdict
# Imports from python Internals
from abc import ABC
from argparse import Namespace
from itertools import cycle
from collections import defaultdict, namedtuple
import sklearn
import torch
# Numerical Imports, Metrics and Plotting
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_auc_score, roc_curve, auc, f1_score, \
recall_score, average_precision_score
from matplotlib import pyplot as plt
# Import Deep Learning Framework
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
@@ -13,15 +19,25 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
# Import Functions and Modules from MLLIB
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from ml_lib.utils.transforms import ToTensor
# Import Project Variables
import variables as V
class BaseLossMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
ce_loss = nn.CrossEntropyLoss()
class BaseOptimizerMixin:
def configure_optimizers(self):
@@ -60,16 +76,12 @@ class BaseOptimizerMixin:
class BaseTrainMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(loss=bce_loss)
loss = self.ce_loss(y.squeeze(), batch_y.long())
return dict(loss=loss)
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
@@ -84,55 +96,39 @@ class BaseTrainMixin:
class BaseValMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(val_bce_loss=val_bce_loss,
val_loss = self.ce_loss(y.squeeze(), batch_y.long())
return dict(val_loss=val_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict()
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
ident = '' if output_idx == 0 else '_train'
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
for output in output]))
for key in keys if 'loss' in key}
)
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy()
keys = list(outputs[0].keys())
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
y_pred = (y_pred >= 0.5).astype(np.float32)
additional_scores = self.additional_scores(outputs)
summary_dict.update(**additional_scores)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict.update({f'uar{ident}_score': uar_score})
for key in summary_dict.keys():
self.log(key, summary_dict[key])
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseTestMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
def test_step(self, batch_xy, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
test_bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(test_bce_loss=test_bce_loss,
test_loss = self.ce_loss(y.squeeze(), batch_y.long())
return dict(test_loss=test_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def test_epoch_end(self, outputs, *_, **__):
@@ -145,16 +141,9 @@ class BaseTestMixin:
for key in keys if 'loss' in key}
)
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
additional_scores = self.additional_scores(outputs)
summary_dict.update(**additional_scores)
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict.update({f'uar_score': uar_score})
for key in summary_dict.keys():
self.log(key, summary_dict[key])
@@ -167,53 +156,56 @@ class DatasetMixin:
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(sr=self.params.sr,
n_mels=self.params.n_mels,
n_fft=self.params.n_fft,
hop_length=self.params.hop_length),
MelToImage()])
mel_transforms_train = Compose([
# Audio to Mel Transformations
Speed(max_amount=self.params.speed_amount,
speed_min=self.params.speed_min,
speed_max=self.params.speed_max
),
mel_transforms])
mel_kwargs = dict(sample_rate=self.params.sr,
n_mels=self.params.n_mels,
n_fft=self.params.n_fft,
hop_length=self.params.hop_length)
# Utility
util_transforms = Compose([NormalizeLocal(), ToTensor()])
normalize = NormalizeLocal()
# Data Augmentations
aug_transforms = Compose([
mel_augmentations = Compose([
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio),
MaskAug(self.params.mask_ratio),
NoiseInjection(0.2),
LoudnessManipulator(0.5),
ShiftTime(0.4),
MaskAug(0.2),
], p=0.6),
util_transforms])
normalize])
# Datasets
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
use_preprocessed=self.params.use_preprocessed,
stretch_dataset=self.params.stretch,
mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET
val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=util_transforms),
val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=util_transforms),
# TEST DATASET
test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=util_transforms),
)
)
Dataset = namedtuple('Datasets', 'train_dataset val_dataset test_dataset')
dataset = Dataset(self.dataset_class(data_root=self.params.root, # TRAIN DATASET
setting=V.DATA_OPTION_train,
fold=list(range(1,8)),
reset=self.params.reset,
mel_kwargs=mel_kwargs,
mel_augmentations=mel_augmentations),
val_dataset=self.dataset_class(data_root=self.params.root, # VALIDATION DATASET
setting=V.DATA_OPTION_devel,
fold=9,
reset=self.params.reset,
mel_kwargs=mel_kwargs,
mel_augmentations=normalize),
test_dataset=self.dataset_class(data_root=self.params.root, # TEST DATASET
setting=V.DATA_OPTION_test,
fold=10,
reset=self.params.reset,
mel_kwargs=mel_kwargs,
mel_augmentations=normalize),
)
if dataset.train_dataset.task_type == V.TASK_OPTION_binary:
# noinspection PyAttributeOutsideInit
self.additional_scores = BinaryScores(self)
elif dataset.train_dataset.task_type == V.TASK_OPTION_multiclass:
# noinspection PyAttributeOutsideInit
self.additional_scores = MultiClassScores(self)
else:
raise ValueError
return dataset
@@ -240,10 +232,185 @@ class BaseDataloadersMixin(ABC):
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
pin_memory=True,
batch_size=self.params.batch_size, shuffle=False)
return [val_dataloader, train_dataloader]
class BaseScores(ABC):
def __init__(self, lightning_model):
self.model = lightning_model
def __call__(self, outputs):
# summary_dict = dict()
# return summary_dict
raise NotImplementedError
class MultiClassScores(BaseScores):
def __init__(self, *args):
super(MultiClassScores, self).__init__(*args)
def __call__(self, outputs):
summary_dict = dict()
#######################################################################################
# Additional Score - UAR - ROC - Conf. Matrix - F1
#######################################################################################
#
# INIT
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_true_one_hot = to_one_hot(y_true, self.model.n_classes)
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in self.model.dataset.test_dataset.classes.items()}
######################################################################################
#
# F1 SCORE
micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro', sample_weight=None,
zero_division=True)
macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None,
zero_division=True)
summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
#######################################################################################
#
# ROC Curve
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(self.model.n_classes):
fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.n_classes)]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(self.model.n_classes):
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= self.model.n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
# Plot all ROC curves
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
label=f'micro ROC ({round(roc_auc["micro"], 2)})',
color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
label=f'macro ROC({round(roc_auc["macro"], 2)})',
color='navy', linestyle=':', linewidth=4)
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'])
for i, color in zip(range(self.model.n_classes), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch, ext='pdf')
plt.clf()
#######################################################################################
#
# ROC SCORE
try:
macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
average="macro")
summary_dict.update(macro_roc_auc_ovr=macro_roc_auc_ovr)
except ValueError:
micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
average="micro")
summary_dict.update(micro_roc_auc_ovr=micro_roc_auc_ovr)
#######################################################################################
#
# Confusion matrix
cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
labels=[class_names[key] for key in class_names.keys()],
normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=[class_names[i] for i in range(self.model.n_classes)]
)
disp.plot(include_values=True)
self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
plt.close('all')
return summary_dict
class BinaryScores(BaseScores):
def __init__(self, *args):
super(BinaryScores, self).__init__(*args)
def __call__(self, outputs):
summary_dict = dict()
# Additional Score like the unweighted Average Recall:
#########################
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
# How to apply a threshold manually
# y_pred = (y_pred >= 0.5).astype(np.float32)
# How to apply a threshold by IF (Isolation Forest)
clf = IsolationForest(random_state=self.model.seed)
y_score = clf.fit_predict(y_pred.reshape(-1, 1))
y_score = (np.asarray(y_score) == -1).astype(np.float32)
uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
summary_dict.update(dict(uar_score=uar_score))
#########################
# Precision
precision_score = average_precision_score(y_true, y_score)
summary_dict.update(dict(precision_score=precision_score))
#########################
# AUC
try:
auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
summary_dict.update(dict(auc_score=auc_score))
except ValueError:
summary_dict.update(dict(auc_score=-1))
#########################
# pAUC
try:
pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
summary_dict.update(dict(pauc_score=pauc))
except ValueError:
summary_dict.update(dict(pauc_score=-1))
return summary_dict

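The BaseScores contract above is small: a callable that takes the epoch's outputs and returns a dict of scalar scores. A hedged sketch of a third plug-in in that style (the class and metric choice are illustrative):

import torch
from sklearn.metrics import accuracy_score
from util.module_mixins import BaseScores

class AccuracyScores(BaseScores):  # hypothetical additional plug-in
    def __call__(self, outputs):
        y_true = torch.cat([o['batch_y'] for o in outputs]).cpu().numpy()
        y_pred = torch.cat([o['y'] for o in outputs]).squeeze().cpu().numpy()
        return dict(accuracy=accuracy_score(y_true, y_pred.argmax(axis=1)))
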
View File

@@ -4,8 +4,22 @@ from argparse import Namespace
CLEAR = 0
MASK = 1
NUM_CLASSES = 2
# Task Options
TASK_OPTION_multiclass = 'multiclass'
N_CLASS_multi = 10
multi_classes_names = ['air_conditioner', 'car_horn', 'children_playing',
'dog_bark', 'drilling', 'engine_idling',
'gun_shot', 'jackhammer', 'siren', 'street_music']
multi_classes = {key: val for val, key in enumerate(multi_classes_names)}
TASK_OPTION_binary = 'binary'
N_CLASS_binary = 2
binary_CLASS_clear = 0
binary_CLASS_mask = 1
# Dataset Options
DATA_OPTIONS = Namespace(test='test', devel='devel', train='train')
DATA_OPTION_test = 'test'
DATA_OPTION_devel = 'devel'
DATA_OPTION_train = 'train'
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
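
A quick sanity check of the reworked constants (illustrative usage):

import variables as V

assert V.DATA_OPTION_train in V.DATA_OPTIONS  # list membership replaces the old Namespace
assert V.multi_classes['siren'] == 8          # index follows multi_classes_names order
id_to_name = {idx: name for name, idx in V.multi_classes.items()}
assert id_to_name[3] == 'dog_bark'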