Model Training
This commit is contained in:
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
import variables as V
|
||||
from ml_lib.modules.utils import F_x
|
||||
@ -11,18 +12,16 @@ from ml_lib.modules.utils import F_x
|
||||
|
||||
class BinaryMasksDataset(Dataset):
|
||||
_to_label = defaultdict(lambda: -1)
|
||||
_to_label['clear'] = V.CLEAR
|
||||
_to_label['mask'] = V.MASK
|
||||
settings = ['test', 'devel', 'train']
|
||||
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, data_root, setting, transforms=None):
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {self.settings}.'
|
||||
assert setting in self.settings, f'Setting must match one of: {self.settings}.'
|
||||
assert callable(transforms) or None, f'Transforms has to be callable, but was: {transforms}'
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
assert callable(transforms) or None, f'Transforms has to be callable, but was: {type(transforms)}'
|
||||
super(BinaryMasksDataset, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
@ -41,7 +40,7 @@ class BinaryMasksDataset(Dataset):
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.split(',')
|
||||
filename, label = row.strip().split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()]
|
||||
return labeldict
|
||||
|
||||
@ -60,5 +59,5 @@ class BinaryMasksDataset(Dataset):
|
||||
pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
with (self._mel_folder / filename).open(mode='rb') as f:
|
||||
sample = pickle.load(f, fix_imports=True)
|
||||
label = self._labels[key]
|
||||
label = torch.as_tensor(self._labels[key], dtype=torch.float)
|
||||
return sample, label
|
||||
|
Reference in New Issue
Block a user