diff --git a/modules/blocks.py b/modules/blocks.py
index f76742d..4489f5c 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -8,7 +8,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 
 import sys
 sys.path.append(str(Path(__file__).parent))
@@ -262,7 +262,10 @@ class Attention(nn.Module):
             mask = F.pad(mask.flatten(1), (1, 0), value=True)
             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
             mask = mask[:, None, :] * mask[:, :, None]
-            dots.masked_fill_(~mask, mask_value)
+            # Expand the pairwise (b, n, n) mask across heads so it broadcasts
+            # against the (b, h, n, n) attention logits.
+            mask = repeat(mask, 'b n d -> b h n d', h=h)
+            # Fill where mask is False (padding) so softmax ignores those slots.
+            dots.masked_fill_(~mask, mask_value)
             del mask
 
             attn = dots.softmax(dim=-1)