New Model, Many Changes

Si11ium
2020-11-21 09:28:25 +01:00
parent 13812b83b5
commit 14ed4e0117
8 changed files with 127 additions and 102 deletions

@@ -1,16 +1,18 @@
 import math
+import warnings
 from pathlib import Path
 from typing import Union

 import torch
-import warnings
 from torch import nn
 from torch.nn import functional as F
+from einops import rearrange

 import sys
 sys.path.append(str(Path(__file__).parent))

-from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten, ResidualBlock, PreNorm

 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -212,81 +214,81 @@ class RecurrentModule(ShapeMixin, nn.Module):
         tensor = self.rnn(x)
         return tensor


-class AttentionModule(ShapeMixin, nn.Module):
-    def __init__(self,in_shape, features, dropout=0.1):
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
-        self.in_shape = in_shape
-        self.dropout = dropout
-        self.features = features
-        raise NotImplementedError
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )

     def forward(self, x):
-        pass
+        return self.net(x)


-class MultiHeadAttentionModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, heads, features, dropout=0.1):
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dropout = 0.):
         super().__init__()
-        self.in_shape = in_shape
-        self.features = features
         self.heads = heads
-        self.final_dim = self.features // self.heads
+        self.scale = dim ** -0.5

-        self.linear_q = LinearModule(self.features, self.features)
-        self.linear_v = LinearModule(self.features, self.features)
-        self.linear_k = LinearModule(self.features, self.features)
-        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
-        self.linear_out = nn.Linear(self.features, self.features)
+        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Dropout(dropout)
+        )

-    def forward(self, q, k, v, mask=None):
-        batch_size = q.size(0)
-
-        # perform linear operation and split into h heads
-        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
-        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
-        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
-        # transpose to get dimensions bs * h * sl * features
-        # ToDo: Do we need this?
-        k = k.transpose(1, 2)
-        q = q.transpose(1, 2)
-        v = v.transpose(1, 2)
-        # calculate attention
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
+    def forward(self, x, mask = None):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv]
+        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max
         if mask is not None:
-            mask = mask.unsqueeze(1)
-            scores = scores.masked_fill(mask == 0, -1e9)
-        scores = F.softmax(scores, dim=-1)
-        scores = self.dropout(scores)
-        scores = torch.matmul(scores, v)
-
-        # concatenate heads and apply final linear transformation
-        # ToDo: This seems to be old coding style. Do we Need this?
-        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
-
-        output = self.out(concat)
-        return output
+            mask = F.pad(mask.flatten(1), [1, 0], value = True)
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots.masked_fill_(~mask, mask_value)
+            del mask
+
+        attn = dots.softmax(dim=-1)
+
+        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return out


+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                ResidualBlock(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                ResidualBlock(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+            ]))
+
+    def forward(self, x, mask = None, *_, **__):
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask)
+            x = ff(x)
+        return x
+
+
 class TransformerModule(ShapeMixin, nn.Module):

-    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, activation='gelu'):
         super(TransformerModule, self).__init__()

         self.in_shape = in_shape

         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
-        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
-                                                   dropout=dropout, activation=kwargs.get('activation')
-                                                   )
         self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, )
+        self.transformer = Transformer(dim=self.flat.flat_shape, depth=num_layers, heads=n_heads,
+                                       mlp_dim=hidden_size, dropout=dropout)

     def forward(self, x, mask=None, key_padding_mask=None):
         tensor = self.flat(x)
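
For orientation, the sketch below walks through the tensor shapes inside the new Attention.forward with made-up example sizes (batch 4, 16 tokens, model dim 64, 8 heads). It is a standalone illustration, not part of the commit; the ResidualBlock and PreNorm helpers imported from .util are not shown in this diff and are presumed to add a skip connection and a pre-LayerNorm, respectively.

# Standalone shape walk-through of the scaled dot-product attention used by the
# new Attention module (illustrative example sizes, not the committed code).
import torch
from torch import nn
from einops import rearrange

b, n, dim, heads = 4, 16, 64, 8                 # batch, tokens, model dim, heads (example values)
x = torch.randn(b, n, dim)

to_qkv = nn.Linear(dim, dim * 3, bias=False)    # one projection produces q, k and v at once
qkv = to_qkv(x).chunk(3, dim=-1)                # three tensors of shape (b, n, dim)
q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in qkv]
print(q.shape)                                  # torch.Size([4, 8, 16, 8]); d = dim // heads

scale = dim ** -0.5                             # the commit scales by the full model dim
dots = torch.einsum('bhid,bhjd->bhij', q, k) * scale
print(dots.shape)                               # torch.Size([4, 8, 16, 16]); token-to-token scores

attn = dots.softmax(dim=-1)                     # attention weights per head
out = torch.einsum('bhij,bhjd->bhid', attn, v)  # weighted sum of the value vectors
out = rearrange(out, 'b h n d -> b n (h d)')    # merge heads back into the feature axis
print(out.shape)                                # torch.Size([4, 16, 64]); same shape as the input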