Merge remote-tracking branch 'origin/master'

# Conflicts:
#	_paramters.py
#	main_inference.py
This commit is contained in:
steffen
2020-05-15 19:48:27 +02:00
2 changed files with 19 additions and 19 deletions
+14 -14
View File
@@ -35,8 +35,20 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3 main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help="") # 0.7 main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Training Parameters # Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
@@ -49,18 +61,6 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="") main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Project Parameters # Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="") main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
+5 -5
View File
@@ -34,10 +34,9 @@ def prepare_dataloader(config_obj):
transforms = Compose([NormalizeLocal(), ToTensor()]) transforms = Compose([NormalizeLocal(), ToTensor()])
aug_transforms = Compose([ aug_transforms = Compose([
NoiseInjection(0.4), NoiseInjection(0.4),
LoudnessManipulator(0), LoudnessManipulator(0.4),
ShiftTime(0), ShiftTime(0.3),
MaskAug(0), MaskAug(0.2),
# Utility
NormalizeLocal(), ToTensor() NormalizeLocal(), ToTensor()
]) ])
@@ -73,8 +72,9 @@ if __name__ == '__main__':
p = Plotter(outpath) p = Plotter(outpath)
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
d = test_dataloader.dataset[0][0].squeeze() d = test_dataloader.dataset[100][0].squeeze()
plt.imshow(d) plt.imshow(d)
p.save_current_figure('100')
loaded_model = restore_logger_and_model(config) loaded_model = restore_logger_and_model(config)
loaded_model.eval() loaded_model.eval()