diff --git a/main.py b/main.py index d429e69..3d17f8f 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,15 @@ warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) +def fix_all_random_seeds(config_obj): + import numpy as np + import torch + import random + np.random.seed(config.main.seed) + torch.manual_seed(config.main.seed) + random.seed(config.main.seed) + + def run_lightning_loop(config_obj): # Logging @@ -124,4 +133,5 @@ if __name__ == "__main__": from _paramters import main_arg_parser config = MConfig.read_argparser(main_arg_parser) + fix_all_random_seeds(config) trained_model = run_lightning_loop(config)