New Model, Many Changes

Si11ium 2020-11-21 09:28:25 +01:00
parent 13812b83b5
commit 14ed4e0117
8 changed files with 127 additions and 102 deletions

View File

@@ -61,10 +61,11 @@ class BaseTrainMixin:
         assert isinstance(self, LightningBaseModule)
         keys = list(outputs[0].keys())
-        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                         for output in outputs]))
-                                 for key in keys if 'loss' in key})
-        return summary_dict
+        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                                for output in outputs]))
+                        for key in keys if 'loss' in key}
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BaseValMixin:

@@ -83,16 +84,16 @@ class BaseValMixin:
     def validation_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict(log=dict())
+        summary_dict = dict()
         # In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
         # for output_idx, output in enumerate(outputs):
         # else:list[dict[]]
         keys = list(outputs.keys())
         # Add Every Value das has a "loss" in it, by calc. mean over all occurences.
-        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                           for output in outputs]))
-                                    for key in keys if 'loss' in key}
-                                   )
+        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                                    for output in outputs]))
+                             for key in keys if 'loss' in key}
+                            )
         """
         # Additional Score like the unweighted Average Recall:
         # UnweightedAverageRecall
@@ -107,7 +108,8 @@ class BaseValMixin:
         summary_dict['log'].update({f'uar_score': uar_score})
         """
-        return summary_dict
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BinaryMaskDatasetMixin:
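The epoch-end hooks above now push their aggregated values through Lightning's `self.log` instead of returning a `dict(log=...)`. A minimal sketch of that pattern (assuming PyTorch Lightning >= 1.0; the class name and keys are illustrative, not from the commit):

    import torch

    class EpochEndLoggingSketch:
        # Intended to be mixed into a LightningModule; mirrors the change above.
        def training_epoch_end(self, outputs):
            loss_keys = [k for k in outputs[0].keys() if 'loss' in k]
            for key in loss_keys:
                mean_value = torch.mean(torch.stack([out[key] for out in outputs]))
                self.log(f'mean_{key}', mean_value)  # replaces: return dict(log={...})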

View File

@@ -1,8 +1,5 @@
 from argparse import Namespace

-from ml_lib.utils.config import Config


 class GlobalVar(Namespace):
     # Labels for classes
     LEFT = 1
@@ -21,10 +18,3 @@ class GlobalVar(Namespace):
        train='train',
        vali='vali',
        test='test'
-
-
-class ThisConfig(Config):
-
-    @property
-    def _model_map(self):
-        return dict()

View File

@@ -12,6 +12,7 @@ class Speed(object):
     def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
         self.speed_max = speed_max if speed_max else 1
         self.speed_min = speed_min if speed_min else 1
+        # noinspection PyTypeChecker
         self.max_amount = min(max(0, max_amount), 1)

     def __call__(self, x):

View File

@@ -1,16 +1,18 @@
-import math
+import warnings
 from pathlib import Path
 from typing import Union

 import torch
-import warnings
 from torch import nn
 from torch.nn import functional as F
+from einops import rearrange

 import sys
 sys.path.append(str(Path(__file__).parent))

-from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten, ResidualBlock, PreNorm

 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -212,81 +214,81 @@ class RecurrentModule(ShapeMixin, nn.Module):
         tensor = self.rnn(x)
         return tensor


-class AttentionModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, features, dropout=0.1):
-        super().__init__()
-        self.in_shape = in_shape
-        self.dropout = dropout
-        self.features = features
-        raise NotImplementedError
-
-    def forward(self, x):
-        pass
-
-
-class MultiHeadAttentionModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, heads, features, dropout=0.1):
-        super().__init__()
-        self.in_shape = in_shape
-        self.features = features
-        self.heads = heads
-        self.final_dim = self.features // self.heads
-
-        self.linear_q = LinearModule(self.features, self.features)
-        self.linear_v = LinearModule(self.features, self.features)
-        self.linear_k = LinearModule(self.features, self.features)
-        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
-        self.linear_out = nn.Linear(self.features, self.features)
-
-    def forward(self, q, k, v, mask=None):
-        batch_size = q.size(0)
-
-        # perform linear operation and split into h heads
-        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
-        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
-        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
-
-        # transpose to get dimensions bs * h * sl * features
-        # ToDo: Do we need this?
-        k = k.transpose(1, 2)
-        q = q.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # calculate attention
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
-        if mask is not None:
-            mask = mask.unsqueeze(1)
-            scores = scores.masked_fill(mask == 0, -1e9)
-        scores = F.softmax(scores, dim=-1)
-        scores = self.dropout(scores)
-
-        scores = torch.matmul(scores, v)
-
-        # concatenate heads and apply final linear transformation
-        # ToDo: This seems to be old coding style. Do we Need this?
-        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
-
-        output = self.out(concat)
-        return output
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dropout = 0.):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim ** -0.5
+
+        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, mask = None):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv]
+
+        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max
+
+        if mask is not None:
+            mask = F.pad(mask.flatten(1), [1, 0], value = True)
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots.masked_fill_(~mask, mask_value)
+            del mask
+
+        attn = dots.softmax(dim=-1)
+
+        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                ResidualBlock(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                ResidualBlock(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+            ]))
+
+    def forward(self, x, mask = None, *_, **__):
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask)
+            x = ff(x)
+        return x


 class TransformerModule(ShapeMixin, nn.Module):

-    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, activation='gelu'):
         super(TransformerModule, self).__init__()

         self.in_shape = in_shape

         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)

-        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
-                                                   dropout=dropout, activation=kwargs.get('activation')
-                                                   )
-        self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, )
+        self.transformer = Transformer(dim=self.flat.flat_shape, depth=num_layers, heads=n_heads,
+                                       mlp_dim=hidden_size, dropout=dropout)

     def forward(self, x, mask=None, key_padding_mask=None):
         tensor = self.flat(x)
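The rewrite above replaces the unfinished AttentionModule/MultiHeadAttentionModule with an einops-based FeedForward/Attention/Transformer stack in the ViT-style pre-norm, residual layout. A minimal usage sketch, not part of the commit, with invented sizes and assuming the new classes are in scope:

    import torch

    transformer = Transformer(dim=64, depth=2, heads=4, mlp_dim=128, dropout=0.1)
    tokens = torch.randn(8, 16, 64)    # (batch, sequence length, embedding dim)
    out = transformer(tokens)          # pre-norm attention + feed-forward, both residual
    assert out.shape == tokens.shape   # the token shape is preserved through every layer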

View File

@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader

-from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
+from .blocks import ConvModule, DeConvModule, LinearModule
 from .util import ShapeMixin, LightningBaseModule, Flatten

@@ -112,6 +112,7 @@ class Generator(ShapeMixin, nn.Module):
         last_shape = re_shape
         for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            # noinspection PyTypeChecker
             self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                   conv_kernel=conv_kernel,
                                                   conv_padding=conv_kernel-2,
@@ -275,16 +276,3 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
-
-
-class TransformerEncoder(ShapeMixin, nn.Module):
-
-    def __init__(self, in_shape):
-        super(TransformerEncoder, self).__init__()
-
-        # MultiheadSelfAttention
-        self.msa = MultiHeadAttentionModule()
-
-    def forward(self, x):

View File

@@ -1,3 +1,5 @@
+from typing import List
+
 from functools import reduce
 from abc import ABC

@@ -6,7 +8,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch.nn import functional as F
+from torch.nn import functional as F, Unfold

 # Utility - Modules
 ###################
@@ -38,6 +40,7 @@ try:
             ################################
             self.hparams = hparams
             self.params = ModelParameters(hparams)
+            self.lr = self.params.lr or 1e-4

         def size(self):
             return self.shape
@@ -76,10 +79,10 @@ try:
             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
             self.apply(weight_initializer)

-    modules = [LightningBaseModule, nn.Module]
+    module_types = (LightningBaseModule, nn.Module,)

 except ImportError:
-    modules = [nn.Module, ]
+    module_types = (nn.Module,)
     pass  # Maybe post a hint to install pytorch-lightning.

@@ -88,7 +91,7 @@ class ShapeMixin:

     @property
     def shape(self):
-        assert isinstance(self, modules)
+        assert isinstance(self, module_types)

         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
@@ -135,6 +138,41 @@ class F_x(ShapeMixin, nn.Module):
         return x


+class ResidualBlock(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+
+class SlidingWindow(nn.Module):
+    def __init__(self, kernel, stride=1, padding=0, keepdim=False):
+        super(SlidingWindow, self).__init__()
+        self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
+        self.padding = padding
+        self.stride = stride
+        self.keepdim = keepdim
+
+        self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
+
+    def forward(self, x):
+        tensor = self._unfolder(x)
+        tensor = tensor.transpose(-1, -2)
+        if self.keepdim:
+            shape = *x.shape[:2], -1, *self.kernel
+            tensor = tensor.reshape(shape)
+        return tensor
+
+
 # Utility - Modules
 ###################
 class Flatten(ShapeMixin, nn.Module):
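A small usage sketch for the new SlidingWindow helper (illustrative only; sizes are invented): it wraps torch.nn.Unfold and yields one flattened patch per sliding position.

    import torch

    window = SlidingWindow(kernel=3, stride=1, padding=0, keepdim=False)
    images = torch.randn(2, 1, 8, 8)     # (batch, channels, height, width)
    patches = window(images)             # unfold, then move the window axis forward
    assert patches.shape == (2, 36, 9)   # 6 x 6 positions, 1 * 3 * 3 values per patch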
@@ -232,14 +270,13 @@ class AutoPadToShape(object):
     def __call__(self, x):
         if not torch.is_tensor(x):
             x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape or x.shape == self.shape:
+        if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
             return x
-        for i in range(-1, -len(self.shape), -1):
-            idx = [0] * len(x.shape)
-            idx[i] = self.shape[i] - x.shape[i]
-            idx = tuple(idx)
-            x = torch.nn.functional.pad(x, idx)
+        idx = [0] * (len(self.shape) * 2)
+        for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
+            idx[j] = self.shape[i] - x.shape[i]
+        x = torch.nn.functional.pad(x, idx)
         return x

     def __repr__(self):
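The reworked AutoPadToShape.__call__ now fills a single F.pad index list over the trailing dimensions instead of padding dimension by dimension. An illustrative check (assuming the constructor simply stores the target shape; values are invented):

    import torch

    pad_to = AutoPadToShape((1, 16, 16))     # target trailing shape
    batch = torch.randn(4, 1, 14, 14)        # smaller than the target in H and W
    padded = pad_to(batch)                   # pads at the front of each short dimension
    assert padded.shape == (4, 1, 16, 16)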

View File

@@ -94,7 +94,7 @@ class Config(ConfigParser, ABC):
         try:
             return locate_and_import_class(self.model.type)
         except AttributeError as e:
-            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
+            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") ' +
                                  f'was not found!\n' +
                                  f'{e}')

View File

@@ -13,6 +13,10 @@ from torch import nn

 # Hyperparamter Object
 class ModelParameters(Namespace, Mapping):

+    @property
+    def activation_as_string(self):
+        return self['activation'].lower()
+
     @property
     def module_kwargs(self):

@@ -56,6 +60,7 @@ class ModelParameters(Namespace, Mapping):
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
+        gelu=nn.GELU,
         elu=nn.ELU,
         relu=nn.ReLU,
         sigmoid=nn.Sigmoid,
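A tiny check of the extended activation lookup (a sketch, not from the commit; the map is restated locally so it runs standalone): the new gelu entry resolves via the same lower-cased name that activation_as_string produces.

    from torch import nn

    _activations = dict(leaky_relu=nn.LeakyReLU, gelu=nn.GELU, elu=nn.ELU,
                        relu=nn.ReLU, sigmoid=nn.Sigmoid)
    assert _activations['GELU'.lower()] is nn.GELU   # mirrors activation_as_string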