CCS intergration dataloader
This commit is contained in:
@ -263,9 +263,8 @@ class Attention(nn.Module):
|
|||||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||||
mask = mask[:, None, :] * mask[:, :, None]
|
mask = mask[:, None, :] * mask[:, :, None]
|
||||||
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
||||||
|
dots.masked_fill_(~mask, mask_value)
|
||||||
#dots.masked_fill_(~mask, mask_value)
|
# dots.masked_fill_(mask, mask_value) # My addition
|
||||||
dots.masked_fill_(mask, mask_value)
|
|
||||||
del mask
|
del mask
|
||||||
|
|
||||||
attn = dots.softmax(dim=-1)
|
attn = dots.softmax(dim=-1)
|
||||||
|
Reference in New Issue
Block a user