New Model, Many Changes

Si11ium 2020-11-21 09:28:25 +01:00
parent 13812b83b5
commit 14ed4e0117
8 changed files with 127 additions and 102 deletions

View File

@ -61,10 +61,11 @@ class BaseTrainMixin:
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseValMixin:
@ -83,13 +84,13 @@ class BaseValMixin:
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
summary_dict = dict()
# In case of multiple given dataloaders, outputs will be: list[list[dict[]]]
# for output_idx, output in enumerate(outputs):
# else: list[dict[]]
keys = list(outputs[0].keys())
# Add every value that has "loss" in its key, by calculating the mean over all occurrences.
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
@ -107,7 +108,8 @@ class BaseValMixin:
summary_dict['log'].update({f'uar_score': uar_score})
"""
return summary_dict
for key in summary_dict.keys():
self.log(key, summary_dict[key])
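Taken together, this hunk replaces the old pattern of returning a dict under a 'log' key with per-key calls to Lightning's self.log. A minimal sketch of how the rewritten training hook reads after the change (same aggregation as above):

    def training_epoch_end(self, outputs, *_, **__):
        # outputs is a list of dicts returned by training_step, e.g. [{'loss': ...}, ...]
        keys = list(outputs[0].keys())
        # mean over all steps for every key that contains 'loss'
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                        for key in keys if 'loss' in key}
        for key, value in summary_dict.items():
            self.log(key, value)  # replaces the removed return dict(log=summary_dict)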
class BinaryMaskDatasetMixin:

View File

@ -1,8 +1,5 @@
from argparse import Namespace
from ml_lib.utils.config import Config
class GlobalVar(Namespace):
# Labels for classes
LEFT = 1
@ -21,10 +18,3 @@ class GlobalVar(Namespace):
train='train',
vali='vali',
test='test'
class ThisConfig(Config):
@property
def _model_map(self):
return dict()

View File

@ -12,6 +12,7 @@ class Speed(object):
def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
self.speed_max = speed_max if speed_max else 1
self.speed_min = speed_min if speed_min else 1
# noinspection PyTypeChecker
self.max_amount = min(max(0, max_amount), 1)
def __call__(self, x):

View File

@ -1,16 +1,18 @@
import math
import warnings
from pathlib import Path
from typing import Union
import torch
import warnings
from torch import nn
from torch.nn import functional as F
from einops import rearrange
import sys
sys.path.append(str(Path(__file__).parent))
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten, ResidualBlock, PreNorm
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -212,81 +214,81 @@ class RecurrentModule(ShapeMixin, nn.Module):
tensor = self.rnn(x)
return tensor
class AttentionModule(ShapeMixin, nn.Module):
def __init__(self,in_shape, features, dropout=0.1):
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.in_shape = in_shape
self.dropout = dropout
self.features = features
raise NotImplementedError
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
pass
return self.net(x)
class MultiHeadAttentionModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, heads, features, dropout=0.1):
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
self.in_shape = in_shape
self.features = features
self.heads = heads
self.final_dim = self.features // self.heads
self.scale = dim ** -0.5
self.linear_q = LinearModule(self.features, self.features)
self.linear_v = LinearModule(self.features, self.features)
self.linear_k = LinearModule(self.features, self.features)
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
self.linear_out = nn.Linear(self.features, self.features)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, q, k, v, mask=None):
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv]
batch_size = q.size(0)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
# perform linear operation and split into h heads
k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
# transpose to get dimensions bs * h * sl * features
# ToDo: Do we need this?
k = k.transpose(1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
# calculate attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim=-1)
scores = self.dropout(scores)
scores = torch.matmul(scores, v)
mask = F.pad(mask.flatten(1), [1, 0], value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, mask_value)
del mask
# concatenate heads and apply final linear transformation
# ToDo: This seems to be old coding style. Do we need this?
concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
attn = dots.softmax(dim=-1)
output = self.out(concat)
return output
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ResidualBlock(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
ResidualBlock(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None, *_, **__):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
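The new Transformer is just a stack of pre-norm residual Attention and FeedForward blocks (ResidualBlock and PreNorm come from .util, see below). A short usage sketch with illustrative dimensions; no mask is passed, so the mask branch in Attention is skipped:

    import torch

    model = Transformer(dim=32, depth=2, heads=8, mlp_dim=64, dropout=0.1)
    tokens = torch.randn(4, 64, 32)       # (batch, sequence length, model dim)
    out = model(tokens)                   # every sub-block is residual, so the shape is preserved
    assert out.shape == tokens.shape      # torch.Size([4, 64, 32])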
class TransformerModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, activation='gelu'):
super(TransformerModule, self).__init__()
self.in_shape = in_shape
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
dropout=dropout, activation=kwargs.get('activation')
)
self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, )
self.transformer = Transformer(dim=self.flat.flat_shape, depth=num_layers, heads=n_heads,
mlp_dim=hidden_size, dropout=dropout)
def forward(self, x, mask=None, key_padding_mask=None):
tensor = self.flat(x)

View File

@ -11,7 +11,7 @@ from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
from .blocks import ConvModule, DeConvModule, LinearModule
from .util import ShapeMixin, LightningBaseModule, Flatten
@ -112,6 +112,7 @@ class Generator(ShapeMixin, nn.Module):
last_shape = re_shape
for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
# noinspection PyTypeChecker
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
@ -275,16 +276,3 @@ class Encoder(BaseEncoder):
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor
class TransformerEncoder(ShapeMixin, nn.Module):
def __init__(self, in_shape):
super(TransformerEncoder, self).__init__()
# MultiheadSelfAttention
self.msa = MultiHeadAttentionModule()
def forward(self, x):

View File

@ -1,3 +1,5 @@
from typing import List
from functools import reduce
from abc import ABC
@ -6,7 +8,7 @@ from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as F, Unfold
# Utility - Modules
###################
@ -38,6 +40,7 @@ try:
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
self.lr = self.params.lr or 1e-4
def size(self):
return self.shape
@ -76,10 +79,10 @@ try:
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
modules = [LightningBaseModule, nn.Module]
module_types = (LightningBaseModule, nn.Module,)
except ImportError:
modules = [nn.Module, ]
module_types = (nn.Module,)
pass # Maybe post a hint to install pytorch-lightning.
@ -88,7 +91,7 @@ class ShapeMixin:
@property
def shape(self):
assert isinstance(self, modules)
assert isinstance(self, module_types)
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
@ -135,6 +138,41 @@ class F_x(ShapeMixin, nn.Module):
return x
class ResidualBlock(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
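ResidualBlock and PreNorm are thin wrappers meant to be chained, as the Transformer in blocks.py does. A short sketch of the composed computation; the wrapped Linear layer is only illustrative:

    import torch
    from torch import nn

    block = ResidualBlock(PreNorm(32, nn.Linear(32, 32)))   # LayerNorm(32) -> Linear -> + skip connection
    x = torch.randn(4, 10, 32)
    y = block(x)   # fn sees the normalised input, the skip connection adds the raw one; shape is unchanged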
class SlidingWindow(nn.Module):
def __init__(self, kernel, stride=1, padding=0, keepdim=False):
super(SlidingWindow, self).__init__()
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
self.padding = padding
self.stride = stride
self.keepdim = keepdim
self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
def forward(self, x):
tensor = self._unfolder(x)
tensor = tensor.transpose(-1, -2)
if self.keepdim:
shape = *x.shape[:2], -1, *self.kernel
tensor = tensor.reshape(shape)
return tensor
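SlidingWindow is a thin wrapper around nn.Unfold that yields one row per extracted patch. A sketch of the shapes it produces; the sizes are illustrative:

    import torch

    sw = SlidingWindow(kernel=4, stride=4)
    x = torch.randn(2, 3, 8, 8)         # (batch, channels, height, width)
    patches = sw(x)                     # torch.Size([2, 4, 48]): 4 patches, each flattened to 3*4*4 values
    windowed = SlidingWindow(kernel=4, stride=4, keepdim=True)(x)
    # keepdim reshapes to (*x.shape[:2], -1, *kernel) -> torch.Size([2, 3, 4, 4, 4])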
# Utility - Modules
###################
class Flatten(ShapeMixin, nn.Module):
@ -232,13 +270,12 @@ class AutoPadToShape(object):
def __call__(self, x):
if not torch.is_tensor(x):
x = torch.as_tensor(x)
if x.shape[1:] == self.shape or x.shape == self.shape:
if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
return x
for i in range(-1, -len(self.shape), -1):
idx = [0] * len(x.shape)
idx[i] = self.shape[i] - x.shape[i]
idx = tuple(idx)
idx = [0] * (len(self.shape) * 2)
for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
idx[j] = self.shape[i] - x.shape[i]
x = torch.nn.functional.pad(x, idx)
return x
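The rewritten loop builds the flat pad list in the layout torch.nn.functional.pad expects: one (before, after) pair per dimension, starting from the last one, with only the 'before' entries filled in. A small sketch, assuming AutoPadToShape simply stores the target shape passed to its constructor:

    import torch

    padder = AutoPadToShape((8, 10))    # hypothetical target: 8 rows, 10 columns
    x = torch.zeros(3, 6, 10)           # already wide enough, two rows short
    y = padder(x)                       # pad list ends up as [0, 0, 2, 0]
    assert y.shape == (3, 8, 10)        # the missing rows are padded at the start of dim -2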

View File

@ -13,6 +13,10 @@ from torch import nn
# Hyperparameter Object
class ModelParameters(Namespace, Mapping):
@property
def activation_as_string(self):
return self['activation'].lower()
@property
def module_kwargs(self):
@ -56,6 +60,7 @@ class ModelParameters(Namespace, Mapping):
_activations = dict(
leaky_relu=nn.LeakyReLU,
gelu=nn.GELU,
elu=nn.ELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,