CCS intergration dataloader

This commit is contained in:
Steffen 2021-03-19 17:17:16 +01:00
parent 8e719af554
commit 43cf0ad00d

View File

@ -262,10 +262,9 @@ class Attention(nn.Module):
mask = F.pad(mask.flatten(1), (1, 0), value=True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
#dots.masked_fill_(~mask, mask_value)
dots.masked_fill_(mask, mask_value)
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
dots.masked_fill_(~mask, mask_value)
# dots.masked_fill_(mask, mask_value) # My addition
del mask
attn = dots.softmax(dim=-1)