64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
import pickle
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import librosa as librosa
|
|
from torch.utils.data import Dataset
|
|
import torch
|
|
|
|
import variables as V
|
|
from ml_lib.modules.utils import F_x
|
|
|
|
|
|
class BinaryMasksDataset(Dataset):
    """Map-style dataset of labelled wav clips with an on-disk transform cache.

    Layout read under ``data_root``:

    * ``lab/labels.csv`` -- one header line, then ``filename,label`` rows.
    * ``wav/<filename>`` -- the audio file referenced by each row.
    * ``transformed/`` -- created on demand; holds one pickled, transformed
      sample per clip so the transform runs only once per file.

    Each item is ``(sample, label)``: ``sample`` is whatever ``transforms``
    returns for the waveform loaded by librosa, and ``label`` is a scalar
    ``torch.float`` tensor.
    """

    # Lower-cased label string -> numeric class. Any string other than
    # 'clear' / 'mask' maps to -1.
    _to_label = defaultdict(lambda: -1)
    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

    @property
    def sample_shape(self):
        """Shape of the first sample (goes through ``__getitem__``, so it may create its cache entry)."""
        return self[0][0].shape

    def __init__(self, data_root, setting, transforms=None):
        """Index the clips of ``data_root`` that belong to ``setting``.

        Args:
            data_root: Dataset root directory (str or path-like).
            setting: Must be one of ``V.DATA_OPTIONS``. Only rows of
                ``labels.csv`` whose line text contains this string are kept.
            transforms: Optional callable applied to the raw waveform. When
                ``None``, ``F_x()`` is used (presumably an identity transform
                -- verify in ``ml_lib.modules.utils``).

        Raises:
            AssertionError: ``setting`` is not a str or not a known option, or
                ``transforms`` is neither ``None`` nor callable.
        """
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        # None is an explicitly allowed value (falls back to F_x() below).
        assert transforms is None or callable(transforms), f'Transforms has to be callable, but was: {type(transforms)}'
        super().__init__()

        self.data_root = Path(data_root)
        self.setting = setting
        # Explicit None check: a callable that happens to be falsy must not be replaced.
        self._transforms = transforms if transforms is not None else F_x()
        self._labels = self._build_labels()
        self._wav_folder = self.data_root / 'wav'
        # Sorted so that integer indices map to files deterministically.
        self._wav_files = sorted(self._labels)
        self._transformed_folder = self.data_root / 'transformed'

    def _build_labels(self):
        """Read ``lab/labels.csv`` and return ``{filename: numeric_label}`` for this setting.

        The first line is skipped as a header. Labels are lower-cased before
        lookup in ``_to_label``; unknown labels become -1.
        """
        labeldict = dict()
        with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r', encoding='utf-8') as f:
            # Skip the header; an empty file simply yields no rows.
            next(f, None)
            for row in f:
                # NOTE(review): substring test on the whole line, so the setting
                # could also match inside the label column -- confirm intended.
                if self.setting not in row:
                    continue
                filename, label = row.strip().split(',')
                # .get() avoids inserting unknown labels into the shared class-level defaultdict.
                labeldict[filename] = self._to_label.get(label.lower(), -1)
        return labeldict

    def __len__(self):
        """Number of labelled clips selected for this setting."""
        return len(self._labels)

    def __getitem__(self, item):
        """Return ``(sample, label)`` for the ``item``-th clip in sorted filename order.

        On first access the wav is loaded with librosa, passed through the
        transforms and pickled into ``transformed/``. Later accesses read that
        pickle instead of recomputing it.
        """
        key = self._wav_files[item]
        cache_path = self._transformed_folder / Path(key).with_suffix('.pik')

        if cache_path.exists():
            # NOTE: unpickling can execute arbitrary code -- only use trusted data roots.
            with cache_path.open(mode='rb') as f:
                sample = pickle.load(f, fix_imports=True)
        else:
            raw_sample, _ = librosa.core.load(self._wav_folder / key)
            sample = self._transforms(raw_sample)
            cache_path.parent.mkdir(exist_ok=True, parents=True)
            # Dump to a temp file and atomically move it into place, so an
            # interrupted write never leaves a truncated cache entry behind.
            tmp_path = cache_path.with_name(cache_path.name + '.tmp')
            with tmp_path.open(mode='wb') as f:
                pickle.dump(sample, f, protocol=pickle.HIGHEST_PROTOCOL)
            tmp_path.replace(cache_path)

        label = torch.as_tensor(self._labels[key], dtype=torch.float)
        return sample, label
|