diff --git a/_paramters.py b/_paramters.py index 9d54918..dbf0144 100644 --- a/_paramters.py +++ b/_paramters.py @@ -33,7 +33,7 @@ main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.2, help="") # 0.2 +main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3 main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7 diff --git a/main.py b/main.py index d429e69..3d17f8f 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,15 @@ warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) +def fix_all_random_seeds(config_obj): + import numpy as np + import torch + import random + np.random.seed(config.main.seed) + torch.manual_seed(config.main.seed) + random.seed(config.main.seed) + + def run_lightning_loop(config_obj): # Logging @@ -124,4 +133,5 @@ if __name__ == "__main__": from _paramters import main_arg_parser config = MConfig.read_argparser(main_arg_parser) + fix_all_random_seeds(config) trained_model = run_lightning_loop(config)