Model Training

This commit is contained in:
Si11ium
2020-05-03 18:00:51 +02:00
parent 8a97f59906
commit e4f6506a4b
9 changed files with 167 additions and 105 deletions

View File

@ -4,6 +4,7 @@ from pathlib import Path
import librosa as librosa
from torch.utils.data import Dataset
import torch
import variables as V
from ml_lib.modules.utils import F_x
@ -11,18 +12,16 @@ from ml_lib.modules.utils import F_x
class BinaryMasksDataset(Dataset):
_to_label = defaultdict(lambda: -1)
_to_label['clear'] = V.CLEAR
_to_label['mask'] = V.MASK
settings = ['test', 'devel', 'train']
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
@property
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, transforms=None):
assert isinstance(setting, str), f'Setting has to be a string, but was: {self.settings}.'
assert setting in self.settings, f'Setting must match one of: {self.settings}.'
assert callable(transforms) or None, f'Transforms has to be callable, but was: {transforms}'
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert callable(transforms) or None, f'Transforms has to be callable, but was: {type(transforms)}'
super(BinaryMasksDataset, self).__init__()
self.data_root = Path(data_root)
@ -41,7 +40,7 @@ class BinaryMasksDataset(Dataset):
for row in f:
if self.setting not in row:
continue
filename, label = row.split(',')
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()]
return labeldict
@ -60,5 +59,5 @@ class BinaryMasksDataset(Dataset):
pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / filename).open(mode='rb') as f:
sample = pickle.load(f, fix_imports=True)
label = self._labels[key]
label = torch.as_tensor(self._labels[key], dtype=torch.float)
return sample, label