diff --git a/_paramters.py b/_paramters.py
index eb0c231..1d41a0f 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -28,6 +28,7 @@ main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
 main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
 main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
 main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--data_stretch", type=strtobool, default=False, help="")
 
 # Transformation Parameters
 main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")  # 0.4
diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py
index 50fdb6c..c3e10e7 100644
--- a/datasets/binar_masks.py
+++ b/datasets/binar_masks.py
@@ -19,7 +19,7 @@ class BinaryMasksDataset(Dataset):
     def sample_shape(self):
         return self[0][0].shape
 
-    def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=True):
+    def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False):
         self.stretch = stretch_dataset
         assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
         assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
diff --git a/models/residual_conv_classifier.py b/models/residual_conv_classifier.py
index b377074..5fcf9fd 100644
--- a/models/residual_conv_classifier.py
+++ b/models/residual_conv_classifier.py
@@ -39,15 +39,18 @@ class ResidualConvClassifier(BinaryMaskDatasetMixin,
         self.conv_list.append(ConvModule(last_shape, self.conv_filters[0], (k, k), conv_stride=(2, 2), conv_padding=1,
                                          **self.params.module_kwargs))
         last_shape = self.conv_list[-1].shape
-        for filters in self.conv_filters:
-            conv_module_params.update(conv_filters=filters)
+        for idx in range(len(self.conv_filters)):
+            conv_module_params.update(conv_filters=self.conv_filters[idx])
             self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params))
             last_shape = self.conv_list[-1].shape
-            self.conv_list.append(ConvModule(last_shape, filters, (k, k), conv_stride=(2, 2), conv_padding=2,
-                                             **self.params.module_kwargs))
-            for param in self.conv_list[-1].parameters():
-                param.requires_grad = False
-            last_shape = self.conv_list[-1].shape
+            try:
+                self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2,
+                                                 **self.params.module_kwargs))
+                for param in self.conv_list[-1].parameters():
+                    param.requires_grad = False
+                last_shape = self.conv_list[-1].shape
+            except IndexError:
+                pass
 
         self.full_1 = LinearModule(self.conv_list[-1].shape, self.params.lat_dim, **self.params.module_kwargs)
         self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
diff --git a/util/module_mixins.py b/util/module_mixins.py
index ea4e2dc..66da9de 100644
--- a/util/module_mixins.py
+++ b/util/module_mixins.py
@@ -142,7 +142,7 @@ class BinaryMaskDatasetMixin:
                 **dict(
                     # TRAIN DATASET
                     train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
-                                                     mixup=self.params.mixup,
+                                                     mixup=self.params.mixup, stretch_dataset=self.params.stretch,
                                                      mel_transforms=mel_transforms_train, transforms=aug_transforms),
                     # VALIDATION DATASET
                     val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
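
Note on the residual_conv_classifier.py hunk: the patched loop walks the filter list by index, appends a ResidualModule for the current filter count, and then tries to append a frozen down-sampling ConvModule sized for the next filter count; on the last entry the conv_filters[idx + 1] lookup raises IndexError, so no trailing down-sampling stage is added. Below is a minimal standalone sketch of just that indexing pattern, in plain Python with made-up filter values and string placeholders instead of the repo's ConvModule/ResidualModule classes:

conv_filters = [32, 64, 128]   # hypothetical filter counts, not the repo's defaults
layers = []
for idx in range(len(conv_filters)):
    layers.append(f"residual({conv_filters[idx]})")             # residual block at the current width
    try:
        layers.append(f"downsample({conv_filters[idx + 1]})")   # frozen down-sampling conv to the next width
    except IndexError:
        pass                                                     # last block has no successor, so no extra conv
print(layers)
# ['residual(32)', 'downsample(64)', 'residual(64)', 'downsample(128)', 'residual(128)']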