Merge remote-tracking branch 'origin/master'

# Conflicts:
#	_paramters.py
#	main_inference.py
This commit is contained in:
steffen
2020-05-15 19:48:27 +02:00
2 changed files with 19 additions and 19 deletions

@ -35,8 +35,20 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help="") # 0.7
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
@ -49,18 +61,6 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")

@ -34,10 +34,9 @@ def prepare_dataloader(config_obj):
transforms = Compose([NormalizeLocal(), ToTensor()])
aug_transforms = Compose([
NoiseInjection(0.4),
LoudnessManipulator(0),
ShiftTime(0),
MaskAug(0),
# Utility
LoudnessManipulator(0.4),
ShiftTime(0.3),
MaskAug(0.2),
NormalizeLocal(), ToTensor()
])
@ -73,8 +72,9 @@ if __name__ == '__main__':
p = Plotter(outpath)
from matplotlib import pyplot as plt
d = test_dataloader.dataset[0][0].squeeze()
d = test_dataloader.dataset[100][0].squeeze()
plt.imshow(d)
p.save_current_figure('100')
loaded_model = restore_logger_and_model(config)
loaded_model.eval()