Merge remote-tracking branch 'origin/master'
# Conflicts: # _paramters.py # main_inference.py
This commit is contained in:
@ -35,8 +35,20 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
|
||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
|
||||
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help="") # 0.7
|
||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
|
||||
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
|
||||
|
||||
# Model Parameters
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
|
||||
main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
|
||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
||||
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
|
||||
|
||||
# Training Parameters
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
@ -49,18 +61,6 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
|
||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||
|
||||
# Model Parameters
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
|
||||
main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
|
||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
||||
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
|
||||
|
||||
# Project Parameters
|
||||
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
|
||||
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
|
||||
|
@ -34,10 +34,9 @@ def prepare_dataloader(config_obj):
|
||||
transforms = Compose([NormalizeLocal(), ToTensor()])
|
||||
aug_transforms = Compose([
|
||||
NoiseInjection(0.4),
|
||||
LoudnessManipulator(0),
|
||||
ShiftTime(0),
|
||||
MaskAug(0),
|
||||
# Utility
|
||||
LoudnessManipulator(0.4),
|
||||
ShiftTime(0.3),
|
||||
MaskAug(0.2),
|
||||
NormalizeLocal(), ToTensor()
|
||||
])
|
||||
|
||||
@ -73,8 +72,9 @@ if __name__ == '__main__':
|
||||
p = Plotter(outpath)
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
d = test_dataloader.dataset[0][0].squeeze()
|
||||
d = test_dataloader.dataset[100][0].squeeze()
|
||||
plt.imshow(d)
|
||||
p.save_current_figure('100')
|
||||
|
||||
loaded_model = restore_logger_and_model(config)
|
||||
loaded_model.eval()
|
||||
|
Reference in New Issue
Block a user