bringing branches up to date
This commit is contained in:
@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
|
||||
|
||||
for attn, mlp in zip(self.attns, self.mlps):
|
||||
# Attention
|
||||
skip_connection = tensor.clone()
|
||||
tensor = self.norm(tensor)
|
||||
attn_tensor = self.norm(tensor)
|
||||
if return_attn_weights:
|
||||
tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
|
||||
attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
|
||||
attn_weights.append(attn_weight)
|
||||
else:
|
||||
tensor = attn(tensor, mask=mask)
|
||||
tensor = tensor + skip_connection
|
||||
attn_tensor = attn(attn_tensor, mask=mask)
|
||||
tensor = attn_tensor + tensor
|
||||
|
||||
# MLP
|
||||
skip_connection = tensor.clone()
|
||||
tensor = self.norm(tensor)
|
||||
tensor = mlp(tensor)
|
||||
tensor = tensor + skip_connection
|
||||
mlp_tensor = self.norm(tensor)
|
||||
mlp_tensor = mlp(mlp_tensor)
|
||||
tensor = tensor + mlp_tensor
|
||||
|
||||
return (tensor, attn_weights) if return_attn_weights else tensor
|
||||
|
Reference in New Issue
Block a user