ResidualModule and New Parameters, Speed Manipulation

This commit is contained in:
Si11ium
2020-05-12 12:37:26 +02:00
parent 3fbc98dfa3
commit 28bfcfdce3
8 changed files with 181 additions and 78 deletions

View File

@ -19,7 +19,8 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False):
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=True):
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
@ -29,11 +30,11 @@ class BinaryMasksDataset(Dataset):
self.mixup = mixup
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._transforms = transforms or F_x()
self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'mel'
self._transforms = transforms or F_x(in_shape=None)
def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
@ -45,6 +46,8 @@ class BinaryMasksDataset(Dataset):
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch:
labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
return labeldict
def __len__(self):
@ -52,7 +55,7 @@ class BinaryMasksDataset(Dataset):
def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename + '.wav'))
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
@ -65,8 +68,9 @@ class BinaryMasksDataset(Dataset):
is_mixed = item >= len(self._labels)
if is_mixed:
item = item - len(self._labels)
key = self._wav_files[item]
filename = key[:-4]
key: str = list(self._labels.keys())[item]
filename = key.replace('.wav', '')
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]