diff --git a/.gitignore b/.gitignore index 7fb11b7..485dee6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1 @@ -/.idea/ -# my own stuff - -/data -/.idea -/ml_lib \ No newline at end of file +.idea diff --git a/modules/blocks.py b/modules/blocks.py index ce3acf8..4489f5c 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -4,11 +4,11 @@ from pathlib import Path from typing import Union import torch -from performer_pytorch import FastAttention + from torch import nn from torch.nn import functional as F -from einops import rearrange +from einops import rearrange, repeat import sys sys.path.append(str(Path(__file__).parent)) @@ -262,7 +262,10 @@ class Attention(nn.Module): mask = F.pad(mask.flatten(1), (1, 0), value=True) assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] - dots.masked_fill_(~mask, mask_value) + mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition + + #dots.masked_fill_(~mask, mask_value) + dots.masked_fill_(mask, mask_value) del mask attn = dots.softmax(dim=-1) diff --git a/utils/config.py b/utils/config.py index 492a574..5111dfe 100644 --- a/utils/config.py +++ b/utils/config.py @@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): parser = ArgumentParser() parser.add_argument('--model_name', type=str) parser.add_argument('--data_name', type=str) + parser.add_argument('--seed', type=str) # Load Defaults from _parameters.ini file config = configparser.ConfigParser() @@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): overrides = overrides or dict() default_data = overrides.get('data_name', None) or new_defaults['data_name'] default_model = overrides.get('model_name', None) or new_defaults['model_name'] + default_seed = overrides.get('seed', None) or new_defaults['seed'] data_name = args.__dict__.get('data_name', None) or default_data model_name = args.__dict__.get('model_name', None) or default_model + found_seed = args.__dict__.get('seed', None) or default_seed new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()}) @@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): if overrides is not None and isinstance(overrides, (Mapping, Dict)): args.update(**overrides) - return args, found_data_class, found_model_class + return args, found_data_class, found_model_class, found_seed def is_jsonable(x): diff --git a/utils/tools.py b/utils/tools.py index 80d0424..08d8f3f 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -29,6 +29,7 @@ def fix_all_random_seeds(seed): np.random.seed(seed) torch.manual_seed(seed) random.seed(seed) + print(f'Seed is now fixed: "{seed}".') def write_to_shelve(file_path, value):