seeding
This commit is contained in:
parent
e7d1a4895a
commit
f458660613
@ -33,7 +33,7 @@ main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help
|
|||||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
|
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") # 0.3
|
||||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.2, help="") # 0.2
|
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
||||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
|
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.3, help="") # 0.3
|
||||||
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
|
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # 0.7
|
||||||
|
|
||||||
|
10
main.py
10
main.py
@ -20,6 +20,15 @@ warnings.filterwarnings('ignore', category=FutureWarning)
|
|||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_all_random_seeds(config_obj):
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
np.random.seed(config.main.seed)
|
||||||
|
torch.manual_seed(config.main.seed)
|
||||||
|
random.seed(config.main.seed)
|
||||||
|
|
||||||
|
|
||||||
def run_lightning_loop(config_obj):
|
def run_lightning_loop(config_obj):
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
@ -124,4 +133,5 @@ if __name__ == "__main__":
|
|||||||
from _paramters import main_arg_parser
|
from _paramters import main_arg_parser
|
||||||
|
|
||||||
config = MConfig.read_argparser(main_arg_parser)
|
config = MConfig.read_argparser(main_arg_parser)
|
||||||
|
fix_all_random_seeds(config)
|
||||||
trained_model = run_lightning_loop(config)
|
trained_model = run_lightning_loop(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user