From 43cf0ad00d3be2e52eeb9e0e94280c4dd530dcaf Mon Sep 17 00:00:00 2001 From: Steffen Date: Fri, 19 Mar 2021 17:17:16 +0100 Subject: [PATCH 1/2] CCS intergration dataloader --- modules/blocks.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/blocks.py b/modules/blocks.py index 4489f5c..602bdb9 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -262,10 +262,9 @@ class Attention(nn.Module): mask = F.pad(mask.flatten(1), (1, 0), value=True) assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] - mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition - - #dots.masked_fill_(~mask, mask_value) - dots.masked_fill_(mask, mask_value) + mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition + dots.masked_fill_(~mask, mask_value) + # dots.masked_fill_(mask, mask_value) # My addition del mask attn = dots.softmax(dim=-1) From 675312537ff627b2e1b0ef005800fdfdb7fc89b9 Mon Sep 17 00:00:00 2001 From: Steffen Date: Fri, 19 Mar 2021 18:05:17 +0100 Subject: [PATCH 2/2] CCS intergration dataloader --- utils/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/config.py b/utils/config.py index 5111dfe..87b3b17 100644 --- a/utils/config.py +++ b/utils/config.py @@ -70,7 +70,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): log_save_interval=10000, # TODO: Better Value / Setting auto_lr_find=not args['debug'], weights_summary='top', - check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1) + check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1), + limit_train_batches = 2.0, + limit_val_batches = 2.0, + limit_test_batches = 2.0, + limit_predict_batches = 2.0, ) if overrides is not None and isinstance(overrides, (Mapping, Dict)):