Test Dataset Multiplication by Timeshift

Si11ium 2020-05-12 16:18:02 +02:00
parent 28bfcfdce3
commit dce799a52b
4 changed files with 9 additions and 7 deletions

View File

@@ -34,7 +34,7 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")
 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")
-main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.5, help="")
+main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="")
 main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="")

 # Training Parameters
@@ -43,8 +43,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
 # FIXME: Stochastic Weight Averaging is not good, maybe it's my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=30, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@@ -53,7 +53,7 @@ main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
 main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 256, 16]", help="")
+main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")

View File

@@ -46,8 +46,10 @@ class BinaryMasksDataset(Dataset):
                 continue
             filename, label = row.strip().split(',')
             labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
-        if self.stretch:
+        if self.stretch and self.setting == V.DATA_OPTIONS.train:
             labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
+            labeldict.update({f'X_X_{key}': val for key, val in labeldict.items()})
+            labeldict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
         return labeldict

     def __len__(self):
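
The two added update calls, together with the pre-existing one, implement the dataset multiplication named in the commit title: each pass re-prefixes every key already in labeldict with X_ groups. A standalone sketch of the growth (the filename is made up); note that the third pass regenerates the existing X_X_X_ key, so each source file ends up with 7 keys rather than 8:

    labeldict = {'clip_001.wav': 1}  # hypothetical filename -> label entry

    labeldict.update({f'X_{key}': val for key, val in labeldict.items()})      # 2 keys
    labeldict.update({f'X_X_{key}': val for key, val in labeldict.items()})    # 4 keys
    labeldict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})  # 7 keys: 'X_X_X_clip_001.wav' is rebuilt and merely overwritten
    assert len(labeldict) == 7

The surviving variants carry between one and six X_ groups, so the effective multiplication factor is 7x per training file.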

View File

@@ -35,7 +35,7 @@ class ConvClassifier(BinaryMaskDatasetFunction,
         last_shape = self.in_shape
         k = 3  # Base Kernel Value
         for filters in self.conv_filters:
-            self.conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2,
+            self.conv_list.append(ConvModule(last_shape, filters, (k, k), conv_stride=(2, 2), conv_padding=2,
                                              **self.params.module_kwargs))
             last_shape = self.conv_list[-1].shape
         # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
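
The surrounding loop threads the running output shape through the stack via last_shape = self.conv_list[-1].shape. ConvModule's internals are not part of this diff, so the following only mirrors that shape-tracking idea with plain torch.nn and a dummy forward pass (input shape and hyperparameters are assumptions):

    import torch
    from torch import nn

    in_shape = (1, 128, 128)          # hypothetical (channels, height, width)
    conv_filters = [32, 64, 128, 64]  # the new default from the parser change above
    k = 3                             # base kernel value

    layers, last_shape = [], in_shape
    for filters in conv_filters:
        layers.append(nn.Conv2d(last_shape[0], filters, (k, k), stride=(2, 2), padding=2))
        with torch.no_grad():  # probe the layer's output shape with a zero tensor
            last_shape = layers[-1](torch.zeros(1, *last_shape)).shape[1:]

    print(last_shape)  # feature-map shape after the whole stack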

View File

@@ -25,7 +25,7 @@ class BaseOptimizerMixin:

     def configure_optimizers(self):
         assert isinstance(self, LightningBaseModule)
-        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04)
+        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
         if self.params.sto_weight_avg:
             opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
         return opt
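
The change above only relaxes the L2 penalty from 0.04 to 1e-7; the SWA wrapping is unchanged context. For orientation, a self-contained sketch of the Adam-plus-SWA wiring, assuming the SWA here is torchcontrib.optim.SWA (the imports sit outside this hunk) and using a stand-in module:

    import torch
    from torch.optim import Adam
    from torchcontrib.optim import SWA  # assumption: the project's SWA import

    model = torch.nn.Linear(10, 2)  # stand-in for the actual LightningBaseModule
    opt = Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-7)
    opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)  # average from step 10 on, every 5 steps

    # ... run the usual training loop calling opt.step() ...
    opt.swap_swa_sgd()  # finally, swap the averaged weights back into the model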