From 10bf376ac32948dd1055d7747c4c957efda476bc Mon Sep 17 00:00:00 2001
From: Steffen
Date: Thu, 18 Mar 2021 12:12:43 +0100
Subject: [PATCH 1/2] Small bugfixes

---
 .gitignore        | 7 +------
 modules/blocks.py | 2 +-
 utils/config.py   | 5 ++++-
 utils/tools.py    | 9 +++++----
 4 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index 7fb11b7..485dee6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1 @@
-/.idea/
-# my own stuff
-
-/data
-/.idea
-/ml_lib
\ No newline at end of file
+.idea
diff --git a/modules/blocks.py b/modules/blocks.py
index ce3acf8..f76742d 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union
 
 import torch
-from performer_pytorch import FastAttention
+
 from torch import nn
 from torch.nn import functional as F
 
diff --git a/utils/config.py b/utils/config.py
index 492a574..5111dfe 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
     parser = ArgumentParser()
     parser.add_argument('--model_name', type=str)
     parser.add_argument('--data_name', type=str)
+    parser.add_argument('--seed', type=str)
 
     # Load Defaults from _parameters.ini file
     config = configparser.ConfigParser()
@@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
     overrides = overrides or dict()
     default_data = overrides.get('data_name', None) or new_defaults['data_name']
     default_model = overrides.get('model_name', None) or new_defaults['model_name']
+    default_seed = overrides.get('seed', None) or new_defaults['seed']
 
     data_name = args.__dict__.get('data_name', None) or default_data
     model_name = args.__dict__.get('model_name', None) or default_model
+    found_seed = args.__dict__.get('seed', None) or default_seed
 
     new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
 
@@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
     if overrides is not None and isinstance(overrides, (Mapping, Dict)):
         args.update(**overrides)
 
-    return args, found_data_class, found_model_class
+    return args, found_data_class, found_model_class, found_seed
 
 
 def is_jsonable(x):
diff --git a/utils/tools.py b/utils/tools.py
index 3119f0e..08d8f3f 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -25,10 +25,11 @@ def to_one_hot(idx_array, max_classes):
     return one_hot
 
 
-def fix_all_random_seeds(config_obj):
-    np.random.seed(config_obj.main.seed)
-    torch.manual_seed(config_obj.main.seed)
-    random.seed(config_obj.main.seed)
+def fix_all_random_seeds(seed):
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
+    print(f'Seed is now fixed: "{seed}".')
 
 
 def write_to_shelve(file_path, value):


From 8e719af5543778c0b17e955e035ef229370ad51a Mon Sep 17 00:00:00 2001
From: Steffen
Date: Thu, 18 Mar 2021 21:34:51 +0100
Subject: [PATCH 2/2] variable mask size, better image shapes

---
 modules/blocks.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/modules/blocks.py b/modules/blocks.py
index f76742d..4489f5c 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from einops import rearrange
+from einops import rearrange, repeat
 
 import sys
 sys.path.append(str(Path(__file__).parent))
@@ -262,7 +262,10 @@ class Attention(nn.Module):
             mask = F.pad(mask.flatten(1), (1, 0), value=True)
             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
             mask = mask[:, None, :] * mask[:, :, None]
-            dots.masked_fill_(~mask, mask_value)
+            mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
+
+            #dots.masked_fill_(~mask, mask_value)
+            dots.masked_fill_(mask, mask_value)
             del mask
 
         attn = dots.softmax(dim=-1)
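
Note (illustrative, not part of the patches): after patch 1,
parse_comandline_args_add_defaults() returns the resolved seed as a fourth
value and fix_all_random_seeds() takes that plain seed instead of a config
object. A minimal sketch of how the two pieces could be wired together at
startup; the entry-point script, the defaults-file path, and the int() cast
are assumptions for illustration only:

    # Sketch only: hypothetical training entry point, not part of this series.
    from utils.config import parse_comandline_args_add_defaults
    from utils.tools import fix_all_random_seeds

    if __name__ == '__main__':
        # The parser now also resolves --seed and returns it as a fourth value.
        args, data_class, model_class, seed = parse_comandline_args_add_defaults(
            '_parameters.ini')  # hypothetical path to the defaults .ini file
        fix_all_random_seeds(int(seed))  # cast, since --seed is declared with type=str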