Test Dataset Multiplication by Timeshift

This commit is contained in:
Si11ium 2020-05-12 16:18:02 +02:00
parent 28bfcfdce3
commit dce799a52b
4 changed files with 9 additions and 7 deletions

View File

@ -34,7 +34,7 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.5, help="")
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="")
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="")
# Training Parameters
@ -43,8 +43,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=30, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@ -53,7 +53,7 @@ main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 256, 16]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")

View File

@ -46,8 +46,10 @@ class BinaryMasksDataset(Dataset):
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch:
if self.stretch and self.setting == V.DATA_OPTIONS.train:
labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
labeldict.update({f'X_X_{key}': val for key, val in labeldict.items()})
labeldict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
return labeldict
def __len__(self):

View File

@ -25,7 +25,7 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt