variable mask size, beter image shapes
This commit is contained in:
		| @@ -8,7 +8,7 @@ import torch | ||||
| from torch import nn | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| from einops import rearrange | ||||
| from einops import rearrange, repeat | ||||
|  | ||||
| import sys | ||||
| sys.path.append(str(Path(__file__).parent)) | ||||
| @@ -262,7 +262,10 @@ class Attention(nn.Module): | ||||
|             mask = F.pad(mask.flatten(1), (1, 0), value=True) | ||||
|             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' | ||||
|             mask = mask[:, None, :] * mask[:, :, None] | ||||
|             dots.masked_fill_(~mask, mask_value) | ||||
|             mask = repeat(mask, 'b n d -> b h n d', h=h)             # My addition | ||||
|  | ||||
|             #dots.masked_fill_(~mask, mask_value) | ||||
|             dots.masked_fill_(mask, mask_value) | ||||
|             del mask | ||||
|  | ||||
|         attn = dots.softmax(dim=-1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen
					Steffen