From 43cf0ad00d3be2e52eeb9e0e94280c4dd530dcaf Mon Sep 17 00:00:00 2001 From: Steffen Date: Fri, 19 Mar 2021 17:17:16 +0100 Subject: [PATCH] CCS intergration dataloader --- modules/blocks.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/blocks.py b/modules/blocks.py index 4489f5c..602bdb9 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -262,10 +262,9 @@ class Attention(nn.Module): mask = F.pad(mask.flatten(1), (1, 0), value=True) assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] - mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition - - #dots.masked_fill_(~mask, mask_value) - dots.masked_fill_(mask, mask_value) + mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition + dots.masked_fill_(~mask, mask_value) + # dots.masked_fill_(mask, mask_value) # My addition del mask attn = dots.softmax(dim=-1)