variable mask size, beter image shapes
This commit is contained in:
parent
10bf376ac3
commit
8e719af554
@ -8,7 +8,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
@ -262,7 +262,10 @@ class Attention(nn.Module):
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value=True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = mask[:, None, :] * mask[:, :, None]
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
||||
|
||||
#dots.masked_fill_(~mask, mask_value)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user