from collections import defaultdict
from pathlib import Path

import librosa
from torch.utils.data import Dataset

import variables as V


class BinaryMasks(Dataset):
    """Audio dataset whose samples are labelled ``clear`` or ``mask``.

    Expected layout under ``data_root``:

    * ``lab/labels.csv`` -- a header line followed by ``filename,label`` rows.
    * ``wav/<filename>`` -- the audio file for each row.

    Only rows containing the ``setting`` string (``'test'``, ``'devel'`` or
    ``'train'``) are kept.
    """

    # Label strings other than 'clear' / 'mask' map to -1.  The factory must
    # be a callable: ``defaultdict(-1)`` raises TypeError when the class body
    # is executed.
    _to_label = defaultdict(lambda: -1)
    _to_label['clear'] = V.CLEAR
    _to_label['mask'] = V.MASK

    def __init__(self, data_root, setting):
        """Index the label file for one partition of the data.

        Args:
            data_root: Directory containing the ``lab`` and ``wav`` folders.
            setting: One of ``'test'``, ``'devel'`` or ``'train'``.

        Raises:
            AssertionError: If ``setting`` is not a string or not one of the
                accepted partition names.
        """
        # Kept as asserts so the exception type seen by callers is unchanged.
        assert isinstance(setting, str)
        assert setting in ['test', 'devel', 'train']
        super().__init__()

        self.data_root = Path(data_root)
        self.setting = setting

        self._labels = self._build_labels()
        self._wav_folder = self.data_root / 'wav'
        # Sorted so that integer indices map to files deterministically.
        self._files = sorted(self._labels)

    def _build_labels(self):
        """Read ``lab/labels.csv`` and return ``{filename: numeric_label}``.

        Rows that do not contain ``self.setting`` are skipped.  Each kept row
        must consist of exactly two comma-separated fields; otherwise the
        tuple unpacking raises ``ValueError``.
        """
        labels_path = self.data_root / 'lab' / 'labels.csv'
        labeldict = {}
        with open(labels_path, mode='r') as f:
            # Skip the header; the default avoids StopIteration on an
            # empty file.
            next(f, None)
            for row in f:
                if self.setting not in row:
                    continue
                # Strip each field: the raw row still carries its trailing
                # newline, which would otherwise make every label lookup
                # miss and fall back to -1.
                filename, label = (field.strip() for field in row.split(','))
                labeldict[filename] = self._to_label[label.lower()]
        return labeldict

    def __len__(self):
        """Return the number of files in the selected partition."""
        return len(self._files)

    def __getitem__(self, item):
        """Load the audio file at position ``item`` together with its label.

        Returns:
            A ``(sample, label)`` pair, where ``sample`` is whatever
            ``librosa.core.load`` returns for the file and ``label`` is the
            numeric label from the CSV (-1 for unrecognised label strings).
        """
        key = self._files[item]
        # ``key`` is already the filename; indexing ``self._files`` (a list)
        # with it again would raise TypeError.
        sample = librosa.core.load(self._wav_folder / key)
        label = self._labels[key]
        return sample, label