From dce799a52b81dff918313dd9ba1d89890cf19f8a Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Tue, 12 May 2020 16:18:02 +0200
Subject: [PATCH] Test Dataset Multiplication by Timeshift

---
 _paramters.py             | 8 ++++----
 datasets/binar_masks.py   | 4 +++-
 models/conv_classifier.py | 2 +-
 util/module_mixins.py     | 2 +-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/_paramters.py b/_paramters.py
index 7a79b94..97bc52d 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -34,7 +34,7 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
 main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")
-main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.5, help="")
+main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="")
 main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="")
 
 # Training Parameters
@@ -43,8 +43,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
 # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=30, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
 
@@ -53,7 +53,7 @@ main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
 main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 256, 16]", help="")
+main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py
index 20a2072..5c03a7c 100644
--- a/datasets/binar_masks.py
+++ b/datasets/binar_masks.py
@@ -46,8 +46,10 @@ class BinaryMasksDataset(Dataset):
                     continue
                 filename, label = row.strip().split(',')
                 labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
-        if self.stretch:
+        if self.stretch and self.setting == V.DATA_OPTIONS.train:
             labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
+            labeldict.update({f'X_X_{key}': val for key, val in labeldict.items()})
+            labeldict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
         return labeldict
 
     def __len__(self):
diff --git a/models/conv_classifier.py b/models/conv_classifier.py
index ca4b92f..f9a3210 100644
--- a/models/conv_classifier.py
+++ b/models/conv_classifier.py
@@ -35,7 +35,7 @@ class ConvClassifier(BinaryMaskDatasetFunction,
         last_shape = self.in_shape
         k = 3 # Base Kernel Value
         for filters in self.conv_filters:
-            self.conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2,
+            self.conv_list.append(ConvModule(last_shape, filters, (k, k), conv_stride=(2, 2), conv_padding=2,
                                              **self.params.module_kwargs))
             last_shape = self.conv_list[-1].shape
         # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
diff --git a/util/module_mixins.py b/util/module_mixins.py
index 5a0c6bc..aa233ba 100644
--- a/util/module_mixins.py
+++ b/util/module_mixins.py
@@ -25,7 +25,7 @@ class BaseOptimizerMixin:
 
     def configure_optimizers(self):
         assert isinstance(self, LightningBaseModule)
-        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04)
+        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
        if self.params.sto_weight_avg:
             opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
         return opt
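
Note on the datasets/binar_masks.py hunk: below is a minimal, runnable sketch of
what the three chained update() calls do to the label dict. The filenames are
hypothetical; only the three update lines come from the patch. Each dict
comprehension is fully evaluated against a snapshot of the current keys before
update() mutates the dict, and the third pass regenerates the already-present
'X_X_X_' keys, so the training split grows by a factor of 7, not 8:

    # Hypothetical filename -> label map, as built by the loop above the hunk.
    labeldict = {'a.wav': 0, 'b.wav': 1}

    # Pass 1: adds an 'X_' copy of every key -> 4 entries.
    labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
    # Pass 2: prefixes every current key, plain and 'X_' alike -> 8 entries.
    labeldict.update({f'X_X_{key}': val for key, val in labeldict.items()})
    # Pass 3: 'X_X_X_' + 'a.wav' collides with the existing key 'X_X_X_a.wav',
    # so each original file gains only 3 genuinely new variants -> 14 entries.
    labeldict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})

    assert len(labeldict) == 14  # 7x the original 2 files, not 8x

If an exact power-of-two multiplication was intended, distinct per-pass prefixes
would avoid the collision; presumably __getitem__ later counts the 'X_' prefixes
on a key to pick the timeshift applied to that sample, per the patch subject.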