fingerprinted now should work correctly
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset
|
||||
@ -19,7 +18,7 @@ class BinaryMasksDataset(Dataset):
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
|
||||
use_preprocessed=True):
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self.stretch = stretch_dataset
|
||||
@ -29,7 +28,6 @@ class BinaryMasksDataset(Dataset):
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
self.setting = setting
|
||||
self.mixup = mixup
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._mel_folder = self.data_root / 'mel'
|
||||
self.container_ext = '.pik'
|
||||
@ -40,19 +38,20 @@ class BinaryMasksDataset(Dataset):
|
||||
self._transforms = transforms or F_x(in_shape=None)
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
labeldict = dict()
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.strip().split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
|
||||
if self.stretch and self.setting == V.DATA_OPTIONS.train:
|
||||
additional_dict = ({f'X_{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
|
||||
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XXXX{key}': val for key, val in labeldict.items()})
|
||||
labeldict.update(additional_dict)
|
||||
|
||||
# Delete File if one exists.
|
||||
@ -66,12 +65,12 @@ class BinaryMasksDataset(Dataset):
|
||||
return labeldict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._labels) * 2 if self.mixup else len(self._labels)
|
||||
return len(self._labels)
|
||||
|
||||
def _compute_or_retrieve(self, filename):
|
||||
|
||||
if not (self._mel_folder / (filename + self.container_ext)).exists():
|
||||
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
|
||||
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
|
||||
@ -82,28 +81,16 @@ class BinaryMasksDataset(Dataset):
|
||||
return mel_sample
|
||||
|
||||
def __getitem__(self, item):
|
||||
is_mixed = item >= len(self._labels)
|
||||
if is_mixed:
|
||||
item = item - len(self._labels)
|
||||
|
||||
key: str = list(self._labels.keys())[item]
|
||||
filename = key.replace('.wav', '')
|
||||
mel_sample = self._compute_or_retrieve(filename)
|
||||
label = self._labels[key]
|
||||
|
||||
if is_mixed:
|
||||
label_sec = -1
|
||||
while label_sec != self._labels[key]:
|
||||
key_sec = random.choice(list(self._labels.keys()))
|
||||
label_sec = self._labels[key_sec]
|
||||
# noinspection PyUnboundLocalVariable
|
||||
filename_sec = key_sec[:-4]
|
||||
mel_sample_sec = self._compute_or_retrieve(filename_sec)
|
||||
mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
|
||||
mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]
|
||||
|
||||
transformed_samples = self._transforms(mel_sample)
|
||||
if not self.setting == 'test':
|
||||
|
||||
if self.setting != V.DATA_OPTIONS.test:
|
||||
# In test, filenames instead of labels are returned. This is a little hacky though.
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
|
||||
return transformed_samples, label
|
||||
|
Reference in New Issue
Block a user