Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
ed260f1c2a
@ -262,10 +262,9 @@ class Attention(nn.Module):
|
|||||||
mask = F.pad(mask.flatten(1), (1, 0), value=True)
|
mask = F.pad(mask.flatten(1), (1, 0), value=True)
|
||||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||||
mask = mask[:, None, :] * mask[:, :, None]
|
mask = mask[:, None, :] * mask[:, :, None]
|
||||||
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
||||||
|
dots.masked_fill_(~mask, mask_value)
|
||||||
#dots.masked_fill_(~mask, mask_value)
|
# dots.masked_fill_(mask, mask_value) # My addition
|
||||||
dots.masked_fill_(mask, mask_value)
|
|
||||||
del mask
|
del mask
|
||||||
|
|
||||||
attn = dots.softmax(dim=-1)
|
attn = dots.softmax(dim=-1)
|
||||||
|
@ -70,7 +70,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
log_save_interval=10000, # TODO: Better Value / Setting
|
log_save_interval=10000, # TODO: Better Value / Setting
|
||||||
auto_lr_find=not args['debug'],
|
auto_lr_find=not args['debug'],
|
||||||
weights_summary='top',
|
weights_summary='top',
|
||||||
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
|
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
|
||||||
|
limit_train_batches = 2.0,
|
||||||
|
limit_val_batches = 2.0,
|
||||||
|
limit_test_batches = 2.0,
|
||||||
|
limit_predict_batches = 2.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user