Model Init
This commit is contained in:
53
datasets/binar_masks.py
Normal file
53
datasets/binar_masks.py
Normal file
@ -0,0 +1,53 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import variables as V
|
||||
|
||||
|
||||
class BinaryMasks(Dataset):
|
||||
_to_label = defaultdict(-1)
|
||||
_to_label['clear'] = V.CLEAR
|
||||
_to_label['mask'] = V.MASK
|
||||
|
||||
def __init__(self, data_root, setting):
|
||||
assert isinstance(setting, str)
|
||||
assert setting in ['test', 'devel', 'train']
|
||||
super(BinaryMasks, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
self.setting = setting
|
||||
self._labels = self._build_labels()
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._files = list(sorted(self._labels.keys()))
|
||||
|
||||
def _build_labels(self):
|
||||
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
labeldict = dict()
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()]
|
||||
return labeldict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._labels)
|
||||
|
||||
def __getitem__(self, item):
|
||||
key = self._files[item]
|
||||
sample = librosa.core.load(self._wav_folder / self._files[key])
|
||||
label = self._labels[key]
|
||||
return sample, label
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user