Small bugfixes
This commit is contained in:
parent
fc4617c9d8
commit
10bf376ac3
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,6 +1 @@
|
|||||||
/.idea/
|
.idea
|
||||||
# my own stuff
|
|
||||||
|
|
||||||
/data
|
|
||||||
/.idea
|
|
||||||
/ml_lib
|
|
||||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from performer_pytorch import FastAttention
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('--model_name', type=str)
|
parser.add_argument('--model_name', type=str)
|
||||||
parser.add_argument('--data_name', type=str)
|
parser.add_argument('--data_name', type=str)
|
||||||
|
parser.add_argument('--seed', type=str)
|
||||||
|
|
||||||
# Load Defaults from _parameters.ini file
|
# Load Defaults from _parameters.ini file
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
overrides = overrides or dict()
|
overrides = overrides or dict()
|
||||||
default_data = overrides.get('data_name', None) or new_defaults['data_name']
|
default_data = overrides.get('data_name', None) or new_defaults['data_name']
|
||||||
default_model = overrides.get('model_name', None) or new_defaults['model_name']
|
default_model = overrides.get('model_name', None) or new_defaults['model_name']
|
||||||
|
default_seed = overrides.get('seed', None) or new_defaults['seed']
|
||||||
|
|
||||||
data_name = args.__dict__.get('data_name', None) or default_data
|
data_name = args.__dict__.get('data_name', None) or default_data
|
||||||
model_name = args.__dict__.get('model_name', None) or default_model
|
model_name = args.__dict__.get('model_name', None) or default_model
|
||||||
|
found_seed = args.__dict__.get('seed', None) or default_seed
|
||||||
|
|
||||||
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
|
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
|
||||||
|
|
||||||
@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
|
|
||||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||||
args.update(**overrides)
|
args.update(**overrides)
|
||||||
return args, found_data_class, found_model_class
|
return args, found_data_class, found_model_class, found_seed
|
||||||
|
|
||||||
|
|
||||||
def is_jsonable(x):
|
def is_jsonable(x):
|
||||||
|
@ -25,10 +25,11 @@ def to_one_hot(idx_array, max_classes):
|
|||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def fix_all_random_seeds(config_obj):
|
def fix_all_random_seeds(seed):
|
||||||
np.random.seed(config_obj.main.seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(config_obj.main.seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(config_obj.main.seed)
|
random.seed(seed)
|
||||||
|
print(f'Seed is now fixed: "{seed}".')
|
||||||
|
|
||||||
|
|
||||||
def write_to_shelve(file_path, value):
|
def write_to_shelve(file_path, value):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user