Audio Dataset

Si11ium 2020-12-01 16:37:16 +01:00
parent 95561acc35
commit 95dcf22f3d
15 changed files with 468 additions and 145 deletions

View File

@@ -2,7 +2,8 @@ from argparse import ArgumentParser, Namespace
from distutils.util import strtobool
from pathlib import Path
import os
NEPTUNE_API_KEY = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUu' \
'YWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0='
# Parameter Configuration
# =============================================================================
@@ -18,10 +19,10 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="")
main_arg_parser.add_argument("--data_worker", type=int, default=6, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
@@ -39,7 +40,7 @@ main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="")
# Model Parameters
# General
main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
main_arg_parser.add_argument("--model_type", type=str, default="HorizontalVisualTransformer", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
@@ -63,8 +64,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@@ -72,7 +73,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="")
main_arg_parser.add_argument("--project_neptune_key", type=str, default=NEPTUNE_API_KEY, help="")
if __name__ == '__main__':
# Parse it
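A minimal standalone sketch of how the strtobool-typed flags above parse on the command line (not from this commit):

```python
from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
# strtobool maps 'y/yes/t/true/on/1' -> 1 and 'n/no/f/false/off/0' -> 0
parser.add_argument("--data_use_preprocessed", type=strtobool, default=True)

args = parser.parse_args(["--data_use_preprocessed", "false"])
print(args.data_use_preprocessed)  # 0 (an int, not a bool -- falsy either way)
```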

View File

@@ -18,15 +18,18 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
@property
def _fingerprint(self):
return dict(**self._mel_kwargs, normalize=self.normalize)
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
self.data_root = Path(data_root)
self.data_root = Path(data_root) / 'ComParE2020_Mask'
self.setting = setting
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
@@ -37,16 +40,36 @@ class BinaryMasksDataset(Dataset):
self._wav_files = list(sorted(self._labels.keys()))
self._transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = True
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = True
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test
with open(Path(self.data_root) / 'lab' / f'{labelfile}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
labeldict[filename] = self._to_label[label.lower()] # if not self.setting == 'test' else filename
if self.stretch and self.setting == V.DATA_OPTIONS.train:
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
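The hunk above introduces a pickled parameter fingerprint that invalidates the on-disk mel cache when preprocessing parameters change. A minimal standalone sketch of that pattern (function and argument names are illustrative, not from the repo; note the stale branch here forces a rebuild, matching the Urban8K variant below rather than the `use_preprocessed = True` seen in this file):

```python
import pickle
from pathlib import Path

def check_cache(cache_dir: Path, fingerprint: dict, use_preprocessed: bool) -> bool:
    """Return True if the on-disk mel cache may be reused as-is."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    store = cache_dir / 'data_params.pik'
    try:
        if pickle.loads(store.read_bytes()) == fingerprint:
            return use_preprocessed            # parameters match: honor the caller's flag
        print('Diverging parameters were found; refreshing...')
        store.unlink()                         # drop the stale fingerprint
    except FileNotFoundError:
        pass                                   # first run: nothing stored yet
    store.write_bytes(pickle.dumps(fingerprint))
    return False                               # cache must be rebuilt
```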

View File

@@ -1,95 +1,140 @@
import pickle
from collections import defaultdict
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class BinaryMasksDataset(Dataset):
_to_label = defaultdict(lambda: -1)
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
class Urban8K(Dataset):
@property
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
@property
def _fingerprint(self):
return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
**_):
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
assert fold in range(1, 11)
super(Urban8K, self).__init__()
self.data_root = Path(data_root)
self.data_root = Path(data_root) / 'UrbanSound8K'
self.setting = setting
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
self._transforms = transforms or F_x(in_shape=None)
transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch and self.setting == V.DATA_OPTIONS.train:
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict)
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
def __len__(self):
return len(self._labels)
return len(self._dataset)
def _compute_or_retrieve(self, filename):
def _pre_process(self):
print('Preprocessing Mel Files....')
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(self._labels)) as pbar:
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
pbar.update()
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
def _build_mel(self, filename):
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
if not mel_file:
raw_sample, sr = librosa.core.load(wav_file)
mel_sample = self._mel_transform(raw_sample)
m, n = mel_sample.shape
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
with mel_file.open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
# print(f"Already existed.. Skipping {filename}")
mel_file = mel_file[0]
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
with mel_file.open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample
return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
key: str = list(self._labels.keys())[item]
filename = key.replace('.wav', '')
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]
transformed_samples = self._transforms(mel_sample)
if self.setting != V.DATA_OPTIONS.test:
# In test, filenames instead of labels are returned. This is a little hacky though.
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label
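A hedged usage sketch for the new Urban8K dataset; the mel pipeline shown is an assumption (the repo wires its own mel_transforms), and 'train' stands in for V.DATA_OPTIONS.train:

```python
import torch
import torchaudio
from torch.utils.data import DataLoader

# assumed mel front end; mirrors the --data_n_mels/--data_sr/--data_hop_length defaults
mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=64, hop_length=256)

train_set = Urban8K(data_root='data', setting='train', fold=1,
                    mel_transforms=lambda wav: mel(torch.as_tensor(wav)),
                    use_preprocessed=True)
loader = DataLoader(train_set, batch_size=200, shuffle=True, num_workers=6)
segment, label = train_set[0]  # one windowed mel segment plus its class label
```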

View File

@@ -0,0 +1,140 @@
import pickle
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class Urban8K_TO(Dataset):
@property
def sample_shape(self):
return self[0][0].shape
@property
def _fingerprint(self):
return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
**_):
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert fold in range(1, 11)
super(Urban8K_TO, self).__init__()
self.data_root = Path(data_root) / 'UrbanSound8K'
self.setting = setting
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
def __len__(self):
return len(self._dataset)
def _pre_process(self):
print('Preprocessing Mel Files....')
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(self._labels)) as pbar:
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
pbar.update()
def _build_mel(self, filename):
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
if not mel_file:
raw_sample, sr = librosa.core.load(wav_file)
mel_sample = self._mel_transform(raw_sample)
m, n = mel_sample.shape
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
self._mel_folder.mkdir(exist_ok=True, parents=True)
with mel_file.open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
# print(f"Already existed.. Skipping {filename}")
mel_file = mel_file[0]
with mel_file.open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label
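Urban8K_TO differs from Urban8K only in its windowing defaults (audio_segment_len=1, audio_hop_len=1). Assuming TorchMelDataset slices each mel spectrogram into fixed-length windows (its source is not part of this diff), the segment count works out as:

```python
def num_segments(n_frames: int, segment_len: int, hop_len: int) -> int:
    # sliding windows of `segment_len` mel frames, advanced by `hop_len` frames
    return max(0, (n_frames - segment_len) // hop_len + 1)

print(num_segments(431, segment_len=62, hop_len=30))  # Urban8K defaults: 13 windows
print(num_segments(431, segment_len=1, hop_len=1))    # Urban8K_TO: 431 one-frame samples
```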

main.py
View File

@@ -82,47 +82,9 @@ def run_lightning_loop(config_obj):
# Save the last state & all parameters
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
model.save_to_disk(logger.log_dir)
# trainer.run_evaluation(test_mode=True)
# Evaluate It
if config_obj.main.eval:
with torch.no_grad():
model.eval()
if torch.cuda.is_available():
model.cuda()
outputs = []
from tqdm import tqdm
for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
batch_x, label = batch
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
label = label.to(device='cuda' if model.on_gpu else 'cpu')
outputs.append(
model.validation_step((batch_x, label), idx, 1)
)
model.validation_epoch_end([outputs])
# trainer.test()
outpath = Path(config_obj.train.outpath)
model_type = config_obj.model.type
parameters = logger.name
version = f'version_{logger.version}'
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
import variables as V
test_dataloader = prepare_dataloader(config_obj)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_names = batch
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out
predictions = (y >= 0.5).int()
for prediction, file_name in zip(predictions, file_names):
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction_text}\n')
return model
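The manual evaluation loop removed above is superseded by the new BaseTestMixin in module_mixins (further down); a hedged sketch of the equivalent built-in call, assuming the PyTorch Lightning API of this era:

```python
# with test_step/test_epoch_end provided by BaseTestMixin,
# the hand-rolled loop collapses to Lightning's test run:
trainer.test(model)
```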

View File

@@ -5,11 +5,11 @@ from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin)
class BandwiseConvClassifier(BinaryMaskDatasetMixin,
class BandwiseConvClassifier(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,

View File

@@ -6,11 +6,11 @@ from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Splitter)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin)
class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
class BandwiseConvMultiheadClassifier(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,

View File

@@ -5,11 +5,11 @@ from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import LightningBaseModule
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin)
class ConvClassifier(BinaryMaskDatasetMixin,
class ConvClassifier(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,

View File

@@ -8,11 +8,11 @@ from torch.nn import ModuleList
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.model_io import SavedLightningModels
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin)
class Ensemble(BinaryMaskDatasetMixin,
class Ensemble(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,

View File

@@ -5,11 +5,11 @@ from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
from ml_lib.modules.util import LightningBaseModule
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin)
class ResidualConvClassifier(BinaryMaskDatasetMixin,
class ResidualConvClassifier(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,

View File

@@ -9,15 +9,16 @@ from einops import rearrange, repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
MIN_NUM_PATCHES = 16
class VisualTransformer(BinaryMaskDatasetMixin,
class VisualTransformer(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
BaseTestMixin,
BaseOptimizerMixin,
LightningBaseModule
):
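Every model now mixes in DatasetMixin (and the transformers additionally BaseTestMixin) ahead of LightningBaseModule. A minimal sketch of why the mixins are listed first (cooperative method resolution; class names here are illustrative):

```python
class Base:
    def setup(self):
        return ['base']

class LoggingMixin:
    def setup(self):
        # mixins listed before the base class win the MRO, then delegate upward
        return ['logging'] + super().setup()

class Model(LoggingMixin, Base):
    pass

print(Model().setup())                      # ['logging', 'base']
print([c.__name__ for c in Model.__mro__])  # ['Model', 'LoggingMixin', 'Base', 'object']
```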

View File

@@ -0,0 +1,111 @@
from argparse import Namespace
import warnings
import torch
from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
MIN_NUM_PATCHES = 16
class HorizontalVisualTransformer(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
BaseTestMixin,
BaseOptimizerMixin,
LightningBaseModule
):
def __init__(self, hparams):
super(HorizontalVisualTransformer, self).__init__(hparams)
# Dataset
# =============================================================================
self.dataset = self.build_dataset()
self.in_shape = self.dataset.train_dataset.sample_shape
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
channels, height, width = self.in_shape
# Model Parameters
# =============================================================================
# Additional parameters
self.embed_dim = self.params.embedding_size
self.patch_size = self.params.patch_size
self.height = height
self.width = width
self.channels = channels
self.new_height = ((self.height - self.patch_size)//1) + 1
num_patches = self.new_height - (self.patch_size // 2)
patch_dim = channels * self.patch_size * self.width
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding Dimension was fixed to be divisible by the number' +
f' of attention heads, is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.new_height, self.width))
self.dropout = nn.Dropout(self.params.dropout)
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.patch_size, self.width),
keepdim=False)
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
n_heads=self.params.heads, num_layers=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation_as_string
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
else F_x(self.embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
)
def forward(self, x, mask=None):
"""
:param x: the sequence to the encoder (required).
:param mask: the mask for the src sequence (optional).
:return:
"""
tensor = self.autopad(x)
tensor = self.slider(tensor)
tensor = self.patch_to_embedding(tensor)
b, n, _ = tensor.shape
# cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
cls_tokens = self.cls_token.repeat((b, 1, 1))
tensor = torch.cat((cls_tokens, tensor), dim=1)
tensor += self.pos_embedding[:, :(n + 1)]
tensor = self.dropout(tensor)
tensor = self.transformer(tensor, mask)
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor)
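A worked shape example for the sliding-window patching above, assuming a (1, 64, 431) mel input and the swept patch_size of 7 (numbers are illustrative):

```python
channels, height, width = 1, 64, 431             # assumed (C, H, W) mel input
patch_size = 7
new_height = ((height - patch_size) // 1) + 1    # 58 vertical window positions
num_patches = new_height - (patch_size // 2)     # 55, as computed in __init__
patch_dim = channels * patch_size * width        # 1 * 7 * 431 = 3017 values per patch
assert num_patches >= 16                         # the MIN_NUM_PATCHES check passes
```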

View File

@@ -7,21 +7,22 @@ from torch import nn
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
BaseDataloadersMixin, BaseTestMixin)
MIN_NUM_PATCHES = 16
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
class VerticalVisualTransformer(DatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
BaseTestMixin,
BaseOptimizerMixin,
LightningBaseModule
):
def __init__(self, hparams):
super(SequentialVisualTransformer, self).__init__(hparams)
super(VerticalVisualTransformer, self).__init__(hparams)
# Dataset
# =============================================================================

View File

@@ -4,6 +4,7 @@ from _paramters import main_arg_parser
from main import run_lightning_loop
import warnings
import shutil
from ml_lib.utils.config import Config
@@ -22,7 +23,7 @@ if __name__ == '__main__':
arg_dict.update(main_seed=seed)
if False:
for patch_size in [3, 5, 9]:
for model in ['SequentialVisualTransformer']:
for model in ['VerticalVisualTransformer']:
arg_dict.update(model_type=model, model_patch_size=patch_size)
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
@@ -52,12 +53,12 @@ if __name__ == '__main__':
arg_dict.update(dicts)
if True:
for patch_size in [3, 7]:
for lat_dim in [4, 32]:
for heads in [2, 4]:
for embedding_size in [32, 64]:
for attn_depth in [1, 3]:
for model in ['SequentialVisualTransformer', 'VisualTransformer']:
for patch_size in [7]:
for lat_dim in [32]:
for heads in [8]:
for embedding_size in [7**2]:
for attn_depth in [1, 3, 5, 7]:
for model in ['HorizontalVisualTransformer']:
arg_dict.update(
model_type=model,
model_patch_size=patch_size,

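The five nested loops above enumerate a small hyperparameter grid; the same sweep written flat with itertools.product (a sketch: the arg_dict keys beyond model_type and model_patch_size are assumed from the pattern, since the hunk is cut off in this diff):

```python
from itertools import product

grid = product([7], [32], [8], [7 ** 2], [1, 3, 5, 7], ['HorizontalVisualTransformer'])
for patch_size, lat_dim, heads, embedding_size, attn_depth, model in grid:
    arg_dict.update(model_type=model, model_patch_size=patch_size,
                    model_lat_dim=lat_dim, model_heads=heads,
                    model_embedding_size=embedding_size, model_attn_depth=attn_depth)
```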
View File

@@ -121,7 +121,45 @@ class BaseValMixin:
self.log(key, summary_dict[key])
class BinaryMaskDatasetMixin:
class BaseTestMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
test_bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(test_bce_loss=test_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def test_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict()
keys = list(outputs[0].keys())
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict.update({f'uar_score': uar_score})
for key in summary_dict.keys():
self.log(key, summary_dict[key])
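The UAR logged above is sklearn's macro-averaged recall; a small worked example of what it measures on an imbalanced batch:

```python
import sklearn.metrics

y_true = [0, 0, 0, 1]   # imbalanced: three 'clear', one 'mask'
y_pred = [0, 0, 1, 1]
# per-class recall: class 0 -> 2/3, class 1 -> 1/1; UAR = (2/3 + 1) / 2
uar = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro')
print(uar)  # 0.833...
```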
class DatasetMixin:
def build_dataset(self):
assert isinstance(self, LightningBaseModule)
@@ -159,21 +197,20 @@ class BinaryMaskDatasetMixin:
util_transforms])
# Datasets
from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
use_preprocessed=self.params.use_preprocessed,
stretch_dataset=self.params.stretch,
mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=util_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=util_transforms),
# TEST DATASET
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=util_transforms),
)
)
@@ -190,22 +227,23 @@ class BaseDataloadersMixin(ABC):
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
sampler = None
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
batch_size=self.params.batch_size,
batch_size=self.params.batch_size, pin_memory=True,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
assert isinstance(self, LightningBaseModule)
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
batch_size=self.params.batch_size, pin_memory=True,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
pin_memory=True,
batch_size=self.params.batch_size, shuffle=False)
return [val_dataloader, train_dataloader]
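The dataloader hunks add pin_memory=True across the board; a minimal sketch of the pattern and its payoff (page-locked host buffers allow asynchronous host-to-GPU copies; `dataset` stands in for any of the sets built above):

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=200, num_workers=6,
                    shuffle=True, pin_memory=True)       # batches land in pinned RAM
for batch_x, batch_y in loader:
    batch_x = batch_x.to('cuda', non_blocking=True)      # async copy from pinned memory
```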