From 10bf376ac32948dd1055d7747c4c957efda476bc Mon Sep 17 00:00:00 2001 From: Steffen Date: Thu, 18 Mar 2021 12:12:43 +0100 Subject: [PATCH] Small bugfixes --- .gitignore | 7 +------ modules/blocks.py | 2 +- utils/config.py | 5 ++++- utils/tools.py | 9 +++++---- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 7fb11b7..485dee6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1 @@ -/.idea/ -# my own stuff - -/data -/.idea -/ml_lib \ No newline at end of file +.idea diff --git a/modules/blocks.py b/modules/blocks.py index ce3acf8..f76742d 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Union import torch -from performer_pytorch import FastAttention + from torch import nn from torch.nn import functional as F diff --git a/utils/config.py b/utils/config.py index 492a574..5111dfe 100644 --- a/utils/config.py +++ b/utils/config.py @@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): parser = ArgumentParser() parser.add_argument('--model_name', type=str) parser.add_argument('--data_name', type=str) + parser.add_argument('--seed', type=str) # Load Defaults from _parameters.ini file config = configparser.ConfigParser() @@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): overrides = overrides or dict() default_data = overrides.get('data_name', None) or new_defaults['data_name'] default_model = overrides.get('model_name', None) or new_defaults['model_name'] + default_seed = overrides.get('seed', None) or new_defaults['seed'] data_name = args.__dict__.get('data_name', None) or default_data model_name = args.__dict__.get('model_name', None) or default_model + found_seed = args.__dict__.get('seed', None) or default_seed new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()}) @@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): if overrides is not None and isinstance(overrides, (Mapping, Dict)): args.update(**overrides) - return args, found_data_class, found_model_class + return args, found_data_class, found_model_class, found_seed def is_jsonable(x): diff --git a/utils/tools.py b/utils/tools.py index 3119f0e..08d8f3f 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -25,10 +25,11 @@ def to_one_hot(idx_array, max_classes): return one_hot -def fix_all_random_seeds(config_obj): - np.random.seed(config_obj.main.seed) - torch.manual_seed(config_obj.main.seed) - random.seed(config_obj.main.seed) +def fix_all_random_seeds(seed): + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + print(f'Seed is now fixed: "{seed}".') def write_to_shelve(file_path, value):