2020-04-15 15:57:49 +02:00

54 lines
1.4 KiB
Python

from collections import defaultdict
from pathlib import Path
import librosa as librosa
from torch.utils.data import Dataset
import variables as V
class BinaryMasks(Dataset):
    """Map-style dataset pairing WAV files with 'clear' / 'mask' labels.

    Labels are read from ``<data_root>/lab/labels.csv`` (one header line,
    then ``filename,label`` rows). Audio is loaded lazily from
    ``<data_root>/wav/<filename>``. Only rows whose text contains
    ``setting`` are kept.
    """

    # Lower-cased label string -> numeric class id (values come from
    # ``variables``). Unknown labels map to -1. ``defaultdict`` requires a
    # callable factory; passing the bare value -1 raises TypeError.
    _to_label = defaultdict(lambda: -1)
    _to_label['clear'] = V.CLEAR
    _to_label['mask'] = V.MASK

    def __init__(self, data_root, setting):
        """Build the label index for one data split.

        Args:
            data_root: directory containing the ``lab`` and ``wav`` folders.
            setting: one of 'test', 'devel', 'train'; selects which CSV rows
                belong to this dataset.

        Raises:
            AssertionError: if ``setting`` is not one of the allowed strings.
        """
        assert isinstance(setting, str)
        assert setting in ['test', 'devel', 'train']
        super(BinaryMasks, self).__init__()
        self.data_root = Path(data_root)
        self.setting = setting
        self._labels = self._build_labels()
        self._wav_folder = self.data_root / 'wav'
        # Sorted so that integer indices map to the same file on every run.
        self._files = sorted(self._labels.keys())

    def _build_labels(self):
        """Parse the label CSV and return ``{filename: class_id}`` for this split."""
        labeldict = dict()
        with open(self.data_root / 'lab' / 'labels.csv', mode='r') as f:
            # Skip the header line; the ``None`` default tolerates an empty file.
            next(f, None)
            for row in f:
                # Substring match on the raw row: presumably the split name is
                # embedded in the filename (e.g. 'train_00001.wav') — TODO confirm.
                if self.setting not in row:
                    continue
                filename, label = row.split(',')
                # Strip the trailing newline / stray spaces; otherwise the label
                # reads e.g. 'clear\n' and never matches a known key. ``.get``
                # avoids inserting unknown labels into the shared class dict.
                labeldict[filename.strip()] = self._to_label.get(
                    label.strip().lower(), -1)
        return labeldict

    def __len__(self):
        """Return the number of labelled files in this split."""
        return len(self._labels)

    def __getitem__(self, item):
        """Return ``(sample, label)`` for the ``item``-th file in sorted order.

        ``sample`` is whatever ``librosa.core.load`` returns for the file
        (per the librosa docs, a ``(waveform, sample_rate)`` tuple).
        """
        # ``key`` is the filename itself, used both for the path and the label.
        key = self._files[item]
        sample = librosa.core.load(str(self._wav_folder / key))
        label = self._labels[key]
        return sample, label