diff --git a/_paramters.py b/_paramters.py index 1d41a0f..23de70c 100644 --- a/_paramters.py +++ b/_paramters.py @@ -22,13 +22,13 @@ main_arg_parser.add_argument("--data_worker", type=int, default=11, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="") main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") -main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="") main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--data_stretch", type=strtobool, default=False, help="") +main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="") # Transformation Parameters main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4 diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py index c3e10e7..f587ba6 100644 --- a/datasets/binar_masks.py +++ b/datasets/binar_masks.py @@ -19,7 +19,9 @@ class BinaryMasksDataset(Dataset): def sample_shape(self): return self[0][0].shape - def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False): + def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False, + use_preprocessed=True): + self.use_preprocessed = use_preprocessed self.stretch = stretch_dataset assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' @@ -28,12 +30,13 @@ class BinaryMasksDataset(Dataset): self.data_root = Path(data_root) self.setting = setting self.mixup = mixup + self._wav_folder = self.data_root / 'wav' + self._mel_folder = self.data_root / 'mel' self.container_ext = '.pik' self._mel_transform = mel_transforms + self._labels = self._build_labels() - self._wav_folder = self.data_root / 'wav' self._wav_files = list(sorted(self._labels.keys())) - self._mel_folder = self.data_root / 'mel' self._transforms = transforms or F_x(in_shape=None) def _build_labels(self): @@ -51,18 +54,25 @@ class BinaryMasksDataset(Dataset): additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()}) additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()}) labeldict.update(additional_dict) + + # Delete File if one exists. + if not self.use_preprocessed: + for key in labeldict.keys(): + (self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink(missing_ok=True) return labeldict def __len__(self): return len(self._labels) * 2 if self.mixup else len(self._labels) def _compute_or_retrieve(self, filename): + if not (self._mel_folder / (filename + self.container_ext)).exists(): raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav')) mel_sample = self._mel_transform(raw_sample) self._mel_folder.mkdir(exist_ok=True, parents=True) with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f: pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) + with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f: mel_sample = pickle.load(f, fix_imports=True) return mel_sample diff --git a/models/residual_conv_classifier.py b/models/residual_conv_classifier.py index 5fcf9fd..c1465e8 100644 --- a/models/residual_conv_classifier.py +++ b/models/residual_conv_classifier.py @@ -41,13 +41,11 @@ class ResidualConvClassifier(BinaryMaskDatasetMixin, last_shape = self.conv_list[-1].shape for idx in range(len(self.conv_filters)): conv_module_params.update(conv_filters=self.conv_filters[idx]) - self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params)) + self.conv_list.append(ResidualModule(last_shape, ConvModule, 2, **conv_module_params)) last_shape = self.conv_list[-1].shape try: self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2, **self.params.module_kwargs)) - for param in self.conv_list[-1].parameters(): - param.requires_grad = False last_shape = self.conv_list[-1].shape except IndexError: pass diff --git a/util/module_mixins.py b/util/module_mixins.py index 66da9de..ee93d81 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -142,6 +142,7 @@ class BinaryMaskDatasetMixin: **dict( # TRAIN DATASET train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, + use_preprocessed=self.params.use_preprocessed, mixup=self.params.mixup, stretch_dataset=self.params.stretch, mel_transforms=mel_transforms_train, transforms=aug_transforms), # VALIDATION DATASET