dataset stretch now optional

This commit is contained in:
Si11ium 2020-05-15 11:06:23 +02:00
parent 09f9dd9131
commit 9ae9143544
4 changed files with 21 additions and 12 deletions

View File

@ -22,19 +22,19 @@ main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
# Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.4, help="") # 0.4
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.4, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.2, help="") # 0.2
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7

View File

@ -19,7 +19,9 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False):
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
@ -28,12 +30,13 @@ class BinaryMasksDataset(Dataset):
self.data_root = Path(data_root)
self.setting = setting
self.mixup = mixup
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'mel'
self._transforms = transforms or F_x(in_shape=None)
def _build_labels(self):
@ -51,18 +54,25 @@ class BinaryMasksDataset(Dataset):
additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict)
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink(missing_ok=True)
return labeldict
def __len__(self):
return len(self._labels) * 2 if self.mixup else len(self._labels)
def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample

View File

@ -41,13 +41,11 @@ class ResidualConvClassifier(BinaryMaskDatasetMixin,
last_shape = self.conv_list[-1].shape
for idx in range(len(self.conv_filters)):
conv_module_params.update(conv_filters=self.conv_filters[idx])
self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params))
self.conv_list.append(ResidualModule(last_shape, ConvModule, 2, **conv_module_params))
last_shape = self.conv_list[-1].shape
try:
self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs))
for param in self.conv_list[-1].parameters():
param.requires_grad = False
last_shape = self.conv_list[-1].shape
except IndexError:
pass

View File

@ -142,6 +142,7 @@ class BinaryMaskDatasetMixin:
**dict(
# TRAIN DATASET
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
use_preprocessed=self.params.use_preprocessed,
mixup=self.params.mixup, stretch_dataset=self.params.stretch,
mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET