import pickle
from collections import defaultdict
from pathlib import Path
import random

import librosa as librosa
from torch.utils.data import Dataset
import torch

import variables as V
from ml_lib.modules.utils import F_x


class BinaryMasksDataset(Dataset):
    """Dataset of mel samples with binary ``clear`` / ``mask`` labels.

    Labels are read from ``<data_root>/lab/labels.csv``, audio from
    ``<data_root>/wav`` and transformed samples are cached as pickles in
    ``<data_root>/mel``.

    With ``stretch_dataset`` every label entry is duplicated under an
    ``X_``-prefixed key (same label, same source wav file, separate cache
    file).  With ``mixup`` the reported length doubles; indices in the upper
    half return the sample of ``index - len(labels)`` partially overwritten
    with a randomly drawn sample carrying the same label.
    """

    # Maps the lower-cased label text of the csv to its numeric label.
    # Any unknown label text resolves to -1.
    _to_label = defaultdict(lambda: -1)
    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

    @property
    def sample_shape(self):
        """Shape of the first (transformed) sample of the dataset."""
        return self[0][0].shape

    def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=True):
        """Set up paths and read the label file.

        Args:
            data_root: Directory that holds the ``lab``, ``wav`` and ``mel`` folders.
            setting: Name of the data split; must be one of ``V.DATA_OPTIONS``.
            mel_transforms: Callable applied to the raw audio loaded by librosa;
                its result is what gets cached on disk.
            transforms: Callable applied to every cached sample on access.
                Defaults to ``F_x(in_shape=None)``.
            mixup: If true, the dataset length doubles and the upper half of
                the index range returns mixed samples.
            stretch_dataset: If true, every label entry is duplicated under an
                ``X_``-prefixed key.
        """
        self.stretch = stretch_dataset
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        super(BinaryMasksDataset, self).__init__()

        self.data_root = Path(data_root)
        self.setting = setting
        self.mixup = mixup
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms
        self._labels = self._build_labels()
        self._wav_folder = self.data_root / 'wav'
        self._wav_files = list(sorted(self._labels.keys()))
        self._mel_folder = self.data_root / 'mel'
        self._transforms = transforms or F_x(in_shape=None)

    def _build_labels(self):
        """Read ``lab/labels.csv`` and return a ``{filename: label}`` mapping.

        Only rows whose text contains ``self.setting`` are kept.  For the
        ``'test'`` setting the filename itself is stored as the label;
        otherwise the label text is mapped through ``_to_label``.

        Returns:
            dict: Mapping in file order, followed by the ``X_``-prefixed
            duplicates if ``self.stretch`` is set.
        """
        labeldict = dict()
        with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
            # Exclude the header; the default keeps an empty file from raising StopIteration.
            next(f, None)
            for row in f:
                # Substring match on the whole row, not just on the filename column.
                if self.setting not in row:
                    continue
                filename, label = row.strip().split(',')
                if self.setting == 'test':
                    labeldict[filename] = filename
                else:
                    labeldict[filename] = self._to_label[label.lower()]
        if self.stretch:
            # The comprehension is fully evaluated before update() mutates the dict.
            labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
        return labeldict

    def __len__(self):
        """Number of label entries, doubled when mixup is enabled."""
        return len(self._labels) * 2 if self.mixup else len(self._labels)

    def _compute_or_retrieve(self, filename):
        """Return the cached sample for ``filename``, computing it if missing.

        Args:
            filename: Label key without the ``.wav`` extension (may carry the
                ``X_`` prefix of a stretched entry).

        Returns:
            The object stored in ``<mel_folder>/<filename><container_ext>``.
        """
        cache_file = self._mel_folder / (filename + self.container_ext)
        if not cache_file.exists():
            # Stretched entries ('X_' keys) are computed from the un-prefixed wav file.
            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
            mel_sample = self._mel_transform(raw_sample)
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with cache_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)

        # The sample is always read back from disk, so callers get a fresh object
        # they may modify in place.
        # NOTE(review): pickle.load executes arbitrary code from the file; only point
        # data_root at directories whose 'mel' cache you trust.
        with cache_file.open(mode='rb') as f:
            mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample

    def __getitem__(self, item):
        """Return ``(transformed_sample, label)`` for index ``item``.

        Indices below ``len(self._labels)`` return the plain sample.  When
        mixup is enabled, indices in ``[len(self._labels), len(self))`` return
        the sample of ``item - len(self._labels)`` with a random leading or
        trailing-bounded slice replaced by a second sample of the same label.

        Returns:
            tuple: The transformed sample and its label.  The label is a float
            tensor unless the setting is ``'test'``, where it is the filename.

        Raises:
            IndexError: If ``item`` is not smaller than ``len(self)``.
        """
        n_labels = len(self._labels)
        # Without this guard an index >= len(self) would silently yield mixed samples
        # (even with mixup disabled) and plain iteration would run past len(self).
        if item >= len(self):
            raise IndexError(f'Index {item} is out of range for a dataset of length {len(self)}.')

        is_mixed = item >= n_labels
        if is_mixed:
            item = item - n_labels

        # Built once per call; also reused for drawing the mix-in partner below.
        keys = list(self._labels)
        key: str = keys[item]
        filename = key.replace('.wav', '')
        mel_sample = self._compute_or_retrieve(filename)
        label = self._labels[key]

        if is_mixed:
            # Rejection sampling: draw until the partner shares the label.  Drawing
            # before testing guarantees key_sec is bound even for the -1 label, and
            # the loop always terminates because `key` itself is a valid partner.
            while True:
                key_sec = random.choice(keys)
                if self._labels[key_sec] == label:
                    break
            # Strips the last four characters (the '.wav' extension of the key).
            filename_sec = key_sec[:-4]
            mel_sample_sec = self._compute_or_retrieve(filename_sec)
            # A positive border overwrites the first `border` entries of the last
            # indexed axis, a negative one everything but the last `|border|` entries.
            mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
            mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]

        transformed_samples = self._transforms(mel_sample)
        if not self.setting == 'test':
            label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label