bringing brances up to date

This commit is contained in:
Steffen Illium
2021-02-15 11:39:54 +01:00
parent 010176e80b
commit a966321576
11 changed files with 216 additions and 197 deletions

View File

@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
for attn, mlp in zip(self.attns, self.mlps):
# Attention
skip_connection = tensor.clone()
tensor = self.norm(tensor)
attn_tensor = self.norm(tensor)
if return_attn_weights:
tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
attn_weights.append(attn_weight)
else:
tensor = attn(tensor, mask=mask)
tensor = tensor + skip_connection
attn_tensor = attn(attn_tensor, mask=mask)
tensor = attn_tensor + tensor
# MLP
skip_connection = tensor.clone()
tensor = self.norm(tensor)
tensor = mlp(tensor)
tensor = tensor + skip_connection
mlp_tensor = self.norm(tensor)
mlp_tensor = mlp(mlp_tensor)
tensor = tensor + mlp_tensor
return (tensor, attn_weights) if return_attn_weights else tensor