Parameter Adjustments and Ensemble Model Implementation

This commit is contained in:
Si11ium
2020-05-08 16:30:55 +02:00
parent c2860b0aed
commit 5e6b0e598f
16 changed files with 648 additions and 313 deletions

View File

@ -1,6 +1,7 @@
import pickle
from collections import defaultdict
from pathlib import Path
import random
import librosa as librosa
from torch.utils.data import Dataset
@ -18,19 +19,21 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
    """Return the ``shape`` of the sample part of the dataset's first item.

    The first item is fetched through ``self[0]``; its first element is the
    sample, and that object's ``shape`` attribute is returned unchanged.
    """
    first_item = self[0]
    first_sample = first_item[0]
    return first_sample.shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False):
    """Set up paths, labels and transforms for one data split.

    Args:
        data_root: Root directory of the dataset. The ``wav``, ``transformed``
            and ``mel`` sub-folders are resolved relative to it, and the labels
            are read from ``lab/labels.csv`` below it by ``_build_labels``.
        setting: Name of the split; must be a string contained in
            ``V.DATA_OPTIONS``.
        mel_transforms: Callable applied to the raw audio when a mel sample is
            computed for the first time (see ``_compute_or_retrieve``).
        transforms: Optional callable applied to the mel sample on every
            access. When omitted (or falsy) ``F_x()`` is used instead
            -- presumably a pass-through transform; confirm against its definition.
        mixup: When truthy, the dataset reports twice its length and the upper
            half of the index range yields mixed samples.

    Raises:
        AssertionError: If ``setting`` is not a valid string option or
            ``transforms`` is neither ``None`` nor callable.
    """
    assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
    assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
    # The former `callable(transforms) or None` evaluated to None for the
    # default `transforms=None`, so the default argument always failed here.
    assert transforms is None or callable(transforms), \
        f'Transforms has to be callable, but was: {type(transforms)}'
    super(BinaryMasksDataset, self).__init__()

    self.data_root = Path(data_root)
    self.setting = setting
    self.mixup = mixup
    self.container_ext = '.pik'
    self._mel_transform = mel_transforms
    self._transforms = transforms or F_x()
    # Needs data_root and setting, so it has to run after they are assigned.
    self._labels = self._build_labels()
    self._wav_folder = self.data_root / 'wav'
    # Sorted keys give a stable index -> file mapping for __getitem__.
    self._wav_files = sorted(self._labels.keys())
    self._transformed_folder = self.data_root / 'transformed'
    self._mel_folder = self.data_root / 'mel'
def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
@ -41,23 +44,45 @@ class BinaryMasksDataset(Dataset):
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()]
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
return labeldict
def __len__(self):
    """Return the number of addressable items.

    With ``mixup`` enabled every labelled file is addressable twice: once as
    the plain sample and once (index shifted by the number of labels) as its
    mixed variant, so the reported length is doubled.
    """
    n_labelled = len(self._labels)
    return n_labelled * 2 if self.mixup else n_labelled
def _compute_or_retrieve(self, filename):
    """Return the mel sample for ``filename``, backed by a pickle cache on disk.

    ``filename`` is the file stem (no extension). If no cache file
    ``<mel_folder>/<filename><container_ext>`` exists yet, the matching
    ``.wav`` file is loaded with librosa, passed through the mel transform and
    pickled into the cache. In every case the returned object is the one read
    back from the cache file.

    NOTE(review): ``pickle.load`` must only ever see cache files written by
    this method; do not point the mel folder at untrusted data.
    """
    cache_path = self._mel_folder / (filename + self.container_ext)

    if not cache_path.exists():
        wav_path = self._wav_folder / (filename + '.wav')
        raw_sample, _sr = librosa.core.load(wav_path)
        # Transform before touching the file system so a failing transform
        # leaves neither a new folder nor an empty cache file behind.
        computed = self._mel_transform(raw_sample)
        self._mel_folder.mkdir(exist_ok=True, parents=True)
        with cache_path.open(mode='wb') as sink:
            pickle.dump(computed, sink, protocol=pickle.HIGHEST_PROTOCOL)

    with cache_path.open(mode='rb') as source:
        mel_sample = pickle.load(source, fix_imports=True)
    return mel_sample
def __getitem__(self, item):
    """Return ``(transformed_sample, label)`` for index ``item``.

    Indices below ``len(self._labels)`` address the plain samples. Indices at
    or above it address the mixup half: they map back onto the same files, but
    a slice of the mel sample is overwritten with the same slice of a randomly
    drawn sample that carries the same label (the sample itself may be drawn).

    Outside the ``'test'`` setting the label is converted to a float tensor;
    in the ``'test'`` setting it is returned exactly as stored.
    """
    n_plain = len(self._labels)
    is_mixed = item >= n_plain
    if is_mixed:
        # Rebind instead of `-=` so a caller-owned index object is never mutated.
        item = item - n_plain

    key = self._wav_files[item]
    filename = key[:-4]  # strips the last four characters (presumably '.wav')
    mel_sample = self._compute_or_retrieve(filename)
    label = self._labels[key]

    if is_mixed:
        # One uniform draw among the keys sharing this label. This replaces a
        # rejection loop that rebuilt the key list on every iteration and left
        # the partner key unbound when the label equalled its -1 sentinel.
        # The list is never empty because `key` itself always qualifies.
        same_label_keys = [k for k, v in self._labels.items() if v == label]
        key_sec = random.choice(same_label_keys)
        mel_sample_sec = self._compute_or_retrieve(key_sec[:-4])

        # Signed border b: for b > 0 the leading b positions of the last axis
        # are replaced, for b < 0 everything except the trailing |b| positions,
        # and b == 0 leaves the sample untouched.
        # NOTE(review): assumes both samples are 2-D with equal last-axis
        # length; a shorter partner would fail the assignment -- confirm.
        mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
        mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]

    transformed_samples = self._transforms(mel_sample)
    if not self.setting == 'test':
        label = torch.as_tensor(label, dtype=torch.float)
    return transformed_samples, label