dataset stretch now optional
This commit is contained in:
parent
09f9dd9131
commit
9ae9143544
@ -22,19 +22,19 @@ main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
|
|||||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
|
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
|
||||||
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
|
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
|
||||||
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
|
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
|
||||||
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
|
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
|
||||||
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
|
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
|
||||||
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
|
||||||
|
|
||||||
# Transformation Parameters
|
# Transformation Parameters
|
||||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.4, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
|
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
|
||||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.4, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.2, help="") # 0.2
|
||||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
|
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
|
||||||
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
|
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
|
||||||
|
|
||||||
|
@ -19,7 +19,9 @@ class BinaryMasksDataset(Dataset):
|
|||||||
def sample_shape(self):
|
def sample_shape(self):
|
||||||
return self[0][0].shape
|
return self[0][0].shape
|
||||||
|
|
||||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False):
|
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
|
||||||
|
use_preprocessed=True):
|
||||||
|
self.use_preprocessed = use_preprocessed
|
||||||
self.stretch = stretch_dataset
|
self.stretch = stretch_dataset
|
||||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||||
@ -28,12 +30,13 @@ class BinaryMasksDataset(Dataset):
|
|||||||
self.data_root = Path(data_root)
|
self.data_root = Path(data_root)
|
||||||
self.setting = setting
|
self.setting = setting
|
||||||
self.mixup = mixup
|
self.mixup = mixup
|
||||||
|
self._wav_folder = self.data_root / 'wav'
|
||||||
|
self._mel_folder = self.data_root / 'mel'
|
||||||
self.container_ext = '.pik'
|
self.container_ext = '.pik'
|
||||||
self._mel_transform = mel_transforms
|
self._mel_transform = mel_transforms
|
||||||
|
|
||||||
self._labels = self._build_labels()
|
self._labels = self._build_labels()
|
||||||
self._wav_folder = self.data_root / 'wav'
|
|
||||||
self._wav_files = list(sorted(self._labels.keys()))
|
self._wav_files = list(sorted(self._labels.keys()))
|
||||||
self._mel_folder = self.data_root / 'mel'
|
|
||||||
self._transforms = transforms or F_x(in_shape=None)
|
self._transforms = transforms or F_x(in_shape=None)
|
||||||
|
|
||||||
def _build_labels(self):
|
def _build_labels(self):
|
||||||
@ -51,18 +54,25 @@ class BinaryMasksDataset(Dataset):
|
|||||||
additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
|
additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
|
||||||
additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
|
additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
|
||||||
labeldict.update(additional_dict)
|
labeldict.update(additional_dict)
|
||||||
|
|
||||||
|
# Delete File if one exists.
|
||||||
|
if not self.use_preprocessed:
|
||||||
|
for key in labeldict.keys():
|
||||||
|
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink(missing_ok=True)
|
||||||
return labeldict
|
return labeldict
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._labels) * 2 if self.mixup else len(self._labels)
|
return len(self._labels) * 2 if self.mixup else len(self._labels)
|
||||||
|
|
||||||
def _compute_or_retrieve(self, filename):
|
def _compute_or_retrieve(self, filename):
|
||||||
|
|
||||||
if not (self._mel_folder / (filename + self.container_ext)).exists():
|
if not (self._mel_folder / (filename + self.container_ext)).exists():
|
||||||
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
|
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
|
||||||
mel_sample = self._mel_transform(raw_sample)
|
mel_sample = self._mel_transform(raw_sample)
|
||||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
|
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
|
||||||
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
|
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
|
||||||
mel_sample = pickle.load(f, fix_imports=True)
|
mel_sample = pickle.load(f, fix_imports=True)
|
||||||
return mel_sample
|
return mel_sample
|
||||||
|
@ -41,13 +41,11 @@ class ResidualConvClassifier(BinaryMaskDatasetMixin,
|
|||||||
last_shape = self.conv_list[-1].shape
|
last_shape = self.conv_list[-1].shape
|
||||||
for idx in range(len(self.conv_filters)):
|
for idx in range(len(self.conv_filters)):
|
||||||
conv_module_params.update(conv_filters=self.conv_filters[idx])
|
conv_module_params.update(conv_filters=self.conv_filters[idx])
|
||||||
self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params))
|
self.conv_list.append(ResidualModule(last_shape, ConvModule, 2, **conv_module_params))
|
||||||
last_shape = self.conv_list[-1].shape
|
last_shape = self.conv_list[-1].shape
|
||||||
try:
|
try:
|
||||||
self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2,
|
self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2,
|
||||||
**self.params.module_kwargs))
|
**self.params.module_kwargs))
|
||||||
for param in self.conv_list[-1].parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
last_shape = self.conv_list[-1].shape
|
last_shape = self.conv_list[-1].shape
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
|
@ -142,6 +142,7 @@ class BinaryMaskDatasetMixin:
|
|||||||
**dict(
|
**dict(
|
||||||
# TRAIN DATASET
|
# TRAIN DATASET
|
||||||
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
|
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
|
||||||
|
use_preprocessed=self.params.use_preprocessed,
|
||||||
mixup=self.params.mixup, stretch_dataset=self.params.stretch,
|
mixup=self.params.mixup, stretch_dataset=self.params.stretch,
|
||||||
mel_transforms=mel_transforms_train, transforms=aug_transforms),
|
mel_transforms=mel_transforms_train, transforms=aug_transforms),
|
||||||
# VALIDATION DATASET
|
# VALIDATION DATASET
|
||||||
|
Loading…
x
Reference in New Issue
Block a user