# Source listing metadata: last modified 2020-12-01 16:37:16 +01:00; 119 lines, 4.6 KiB, Python.

import pickle
from collections import defaultdict
from pathlib import Path
import librosa as librosa
from torch.utils.data import Dataset
import torch
import variables as V
from ml_lib.modules.util import F_x
class BinaryMasksDataset(Dataset):
    """Mel-spectrogram dataset over the ``ComParE2020_Mask`` audio folder.

    Raw audio is read from ``<data_root>/ComParE2020_Mask/wav``, converted by
    ``mel_transforms`` and cached as pickles in ``<data_root>/ComParE2020_Mask/mel``.
    Labels are read from the CSV files in ``<data_root>/ComParE2020_Mask/lab``.

    A fingerprint of the preprocessing parameters is stored in the cache folder
    (``data_params.pik``). When it no longer matches the current parameters, all
    cached containers are discarded and recomputed lazily on access.
    """

    # Lower-cased CSV label -> class id; labels not listed here map to -1.
    _to_label = defaultdict(lambda: -1)
    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

    # Name of the fingerprint file inside the mel cache folder.
    _param_file_name = 'data_params.pik'
    # Prefixes used to duplicate train entries when stretching the dataset.
    _stretch_prefixes = ('X', 'XX', 'XXX')

    @property
    def sample_shape(self):
        """Shape of the first transformed sample (loads/computes item 0)."""
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """Preprocessing parameters that the on-disk mel cache depends on."""
        return dict(**self._mel_kwargs, normalize=self.normalize)

    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                 use_preprocessed=True, mel_kwargs=None, normalize=None):
        """
        :param data_root: Directory containing the ``ComParE2020_Mask`` folder.
        :param setting: Data split name; must be contained in ``V.DATA_OPTIONS``.
        :param mel_transforms: Callable applied to the waveform returned by
            ``librosa.load``; its result is what gets cached on disk.
        :param transforms: Callable applied to every cached sample in
            ``__getitem__``; defaults to ``F_x(in_shape=None)``.
        :param stretch_dataset: If true and ``setting`` is the train split, every
            entry is additionally listed with ``X``/``XX``/``XXX`` prefixes.
        :param use_preprocessed: If false, cached containers of this split are
            deleted so they are recomputed on access.
        :param mel_kwargs: Parameters of ``mel_transforms`` recorded in the cache
            fingerprint, so that changing them invalidates the cache.
        :param normalize: Normalisation setting recorded in the cache fingerprint.
        """
        super(BinaryMasksDataset, self).__init__()
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'

        self.stretch = stretch_dataset
        self.setting = setting

        # ``_fingerprint`` reads these two attributes; they used to be read
        # without ever being assigned. Keep values a subclass may have set
        # before calling this constructor when no explicit argument is given.
        if mel_kwargs is None:
            mel_kwargs = getattr(self, '_mel_kwargs', None) or {}
        if normalize is None:
            normalize = getattr(self, 'normalize', None)
        self._mel_kwargs = dict(mel_kwargs)
        self.normalize = normalize

        self.data_root = Path(data_root) / 'ComParE2020_Mask'
        self._wav_folder = self.data_root / 'wav'
        self._mel_folder = self.data_root / 'mel'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms
        self._transforms = transforms or F_x(in_shape=None)

        # Must run before _build_labels, which reads self.use_preprocessed to
        # decide whether cached containers have to be dropped.
        self.use_preprocessed = self._sync_cache_fingerprint(use_preprocessed)

        self._labels = self._build_labels()
        # Insertion order of self._labels defines the index -> sample mapping;
        # computed once instead of on every __getitem__ call.
        self._keys = list(self._labels)
        self._wav_files = list(sorted(self._labels.keys()))

    @staticmethod
    def _stem(key):
        """Strip a trailing ``.wav`` from a label key."""
        return key[:-len('.wav')] if key.endswith('.wav') else key

    def _container_path(self, filename):
        """Path of the cached mel container for ``filename`` (given without extension)."""
        return self._mel_folder / (filename + self.container_ext)

    def _clear_mel_cache(self, keep):
        """Delete every cached container in the mel folder except the file ``keep``."""
        for cached in self._mel_folder.glob(f'*{self.container_ext}'):
            if cached == keep:
                continue
            try:
                cached.unlink()
            except FileNotFoundError:
                pass

    def _sync_cache_fingerprint(self, use_preprocessed):
        """Compare the stored preprocessing fingerprint with the current one.

        Writes the current fingerprint when none is stored. When the stored one
        diverges (or cannot be unpickled), every cached container in the mel
        folder is deleted, since all splits share this folder and fingerprint
        file, and the new fingerprint is written.

        :param use_preprocessed: The caller's preference.
        :return: The value to use for ``self.use_preprocessed``.
        """
        self._mel_folder.mkdir(parents=True, exist_ok=True)
        param_storage = self._mel_folder / self._param_file_name
        current = self._fingerprint
        try:
            # NOTE: pickle is only safe on trusted data; this file is written by
            # this class itself.
            stored = pickle.loads(param_storage.read_bytes())
        except FileNotFoundError:
            param_storage.write_bytes(pickle.dumps(current))
            return use_preprocessed
        except (pickle.UnpicklingError, EOFError):
            stored = None  # Unreadable fingerprint: treat it as diverging.

        if stored == current:
            return use_preprocessed

        print('Diverging parameters were found; Refreshing...')
        self._clear_mel_cache(keep=param_storage)
        param_storage.write_bytes(pickle.dumps(current))
        return False

    def _build_labels(self):
        """Map the wav file names of this split to class ids.

        Non-test splits read ``lab/labels.csv``; the test split reads
        ``lab/<V.DATA_OPTIONS.test>.csv``. The header row is skipped and only
        rows containing the setting name are kept. Labels missing from
        ``_to_label`` map to -1. With stretching enabled on the train split,
        ``X``/``XX``/``XXX``-prefixed duplicates of all entries are appended.
        If ``self.use_preprocessed`` is false, the cached containers of all
        returned keys are deleted.
        """
        labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test
        labeldict = dict()
        with open(self.data_root / 'lab' / f'{labelfile}.csv', mode='r') as f:
            next(f, None)  # Skip the header (tolerates an empty file).
            for row in f:
                if self.setting not in row:
                    continue
                filename, label = row.strip().split(',')
                # .get() rather than [] so unknown labels are not inserted into
                # the class-level defaultdict shared by every instance.
                labeldict[filename] = self._to_label.get(label.lower(), -1)

        if self.stretch and self.setting == V.DATA_OPTIONS.train:
            originals = list(labeldict.items())
            for prefix in self._stretch_prefixes:
                labeldict.update({f'{prefix}{key}': val for key, val in originals})

        if not self.use_preprocessed:
            for key in labeldict:
                try:
                    self._container_path(self._stem(key)).unlink()
                except FileNotFoundError:
                    pass
        return labeldict

    def __len__(self):
        return len(self._labels)

    def _compute_or_retrieve(self, filename):
        """Return the mel sample for ``filename``, computing and caching it if needed.

        ``filename`` is a label key without ``.wav``; leading ``X`` stretch
        prefixes are stripped to locate the source wav file.
        """
        container = self._container_path(filename)
        if container.exists():
            # NOTE: pickle is only safe on trusted data; these containers are
            # written by this class itself.
            with container.open(mode='rb') as f:
                return pickle.load(f)

        raw_sample, _ = librosa.load(self._wav_folder / (filename.lstrip('X') + '.wav'))
        mel_sample = self._mel_transform(raw_sample)
        self._mel_folder.mkdir(exist_ok=True, parents=True)
        # Write to a temporary file and rename it into place, so an interrupted
        # write never leaves a truncated container behind.
        tmp = container.with_name(container.name + '.tmp')
        with tmp.open(mode='wb') as f:
            pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        tmp.replace(container)
        return mel_sample

    def __getitem__(self, item):
        """Return ``(transformed_sample, label)`` for index ``item``.

        For non-test splits the label is a float tensor. For the test split the
        value stored in the label dict is returned as-is (-1 for labels that
        are not in ``_to_label``).
        """
        key: str = self._keys[item]
        mel_sample = self._compute_or_retrieve(self._stem(key))
        transformed_samples = self._transforms(mel_sample)
        label = self._labels[key]
        # NOTE(review): an earlier comment claimed filenames are returned for
        # the test split; the code returns the stored label value instead.
        if self.setting != V.DATA_OPTIONS.test:
            label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label