Merge remote-tracking branch 'origin/master'

# Conflicts:
#	utils/tools.py
This commit is contained in:
Steffen Illium 2021-03-18 21:44:18 +01:00
commit 479514c9e7
4 changed files with 12 additions and 10 deletions

7
.gitignore vendored
View File

@ -1,6 +1 @@
/.idea/ .idea
# my own stuff
/data
/.idea
/ml_lib

View File

@ -4,11 +4,11 @@ from pathlib import Path
from typing import Union from typing import Union
import torch import torch
from performer_pytorch import FastAttention
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from einops import rearrange from einops import rearrange, repeat
import sys import sys
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
@ -262,7 +262,10 @@ class Attention(nn.Module):
mask = F.pad(mask.flatten(1), (1, 0), value=True) mask = F.pad(mask.flatten(1), (1, 0), value=True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None] mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, mask_value) mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
#dots.masked_fill_(~mask, mask_value)
dots.masked_fill_(mask, mask_value)
del mask del mask
attn = dots.softmax(dim=-1) attn = dots.softmax(dim=-1)

View File

@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--model_name', type=str) parser.add_argument('--model_name', type=str)
parser.add_argument('--data_name', type=str) parser.add_argument('--data_name', type=str)
parser.add_argument('--seed', type=str)
# Load Defaults from _parameters.ini file # Load Defaults from _parameters.ini file
config = configparser.ConfigParser() config = configparser.ConfigParser()
@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
overrides = overrides or dict() overrides = overrides or dict()
default_data = overrides.get('data_name', None) or new_defaults['data_name'] default_data = overrides.get('data_name', None) or new_defaults['data_name']
default_model = overrides.get('model_name', None) or new_defaults['model_name'] default_model = overrides.get('model_name', None) or new_defaults['model_name']
default_seed = overrides.get('seed', None) or new_defaults['seed']
data_name = args.__dict__.get('data_name', None) or default_data data_name = args.__dict__.get('data_name', None) or default_data
model_name = args.__dict__.get('model_name', None) or default_model model_name = args.__dict__.get('model_name', None) or default_model
found_seed = args.__dict__.get('seed', None) or default_seed
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()}) new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
if overrides is not None and isinstance(overrides, (Mapping, Dict)): if overrides is not None and isinstance(overrides, (Mapping, Dict)):
args.update(**overrides) args.update(**overrides)
return args, found_data_class, found_model_class return args, found_data_class, found_model_class, found_seed
def is_jsonable(x): def is_jsonable(x):

View File

@ -29,6 +29,7 @@ def fix_all_random_seeds(seed):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
print(f'Seed is now fixed: "{seed}".')
def write_to_shelve(file_path, value): def write_to_shelve(file_path, value):