diff --git a/_templates/new_project/utils/module_mixins.py b/_templates/new_project/utils/module_mixins.py
index 609977b..b11cffb 100644
--- a/_templates/new_project/utils/module_mixins.py
+++ b/_templates/new_project/utils/module_mixins.py
@@ -61,10 +61,11 @@ class BaseTrainMixin:
         assert isinstance(self, LightningBaseModule)
         keys = list(outputs[0].keys())
 
-        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                        for output in outputs]))
-                                 for key in keys if 'loss' in key})
-        return summary_dict
+        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                               for output in outputs]))
+                        for key in keys if 'loss' in key}
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])
 
 
 class BaseValMixin:
@@ -83,16 +84,16 @@ class BaseValMixin:
 
     def validation_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict(log=dict())
+        summary_dict = dict()
         # In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
         # for output_idx, output in enumerate(outputs):
         # else:list[dict[]]
         keys = list(outputs.keys())
         # Add Every Value das has a "loss" in it, by calc. mean over all occurences.
-        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                           for output in outputs]))
-                                    for key in keys if 'loss' in key}
-                                   )
+        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                                    for output in outputs]))
+                             for key in keys if 'loss' in key}
+                            )
         """
         # Additional Score like the unweighted Average Recall:
         # UnweightedAverageRecall
@@ -107,7 +108,8 @@ class BaseValMixin:
 
             summary_dict['log'].update({f'uar_score': uar_score})
         """
-        return summary_dict
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])
 
 
 class BinaryMaskDatasetMixin:
diff --git a/_templates/new_project/utils/project_config.py b/_templates/new_project/utils/project_config.py
index 8b651b3..1861da3 100644
--- a/_templates/new_project/utils/project_config.py
+++ b/_templates/new_project/utils/project_config.py
@@ -1,8 +1,5 @@
 from argparse import Namespace
 
-from ml_lib.utils.config import Config
-
-
 class GlobalVar(Namespace):
     # Labels for classes
     LEFT = 1
@@ -21,10 +18,3 @@ class GlobalVar(Namespace):
                          train='train',
                          vali='vali',
                          test='test'
-
-
-class ThisConfig(Config):
-
-    @property
-    def _model_map(self):
-        return dict()
diff --git a/audio_toolset/audio_augmentation.py b/audio_toolset/audio_augmentation.py
index d8a7903..17ea0ef 100644
--- a/audio_toolset/audio_augmentation.py
+++ b/audio_toolset/audio_augmentation.py
@@ -12,6 +12,7 @@ class Speed(object):
     def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
         self.speed_max = speed_max if speed_max else 1
         self.speed_min = speed_min if speed_min else 1
+        # noinspection PyTypeChecker
         self.max_amount = min(max(0, max_amount), 1)
 
     def __call__(self, x):
diff --git a/modules/blocks.py b/modules/blocks.py
index df17d2d..f3ebbe8 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -1,16 +1,18 @@
-import math
+import warnings
 from pathlib import Path
 from typing import Union
 
 import torch
-import warnings
-
 from torch import nn
 from torch.nn import functional as F
+
+from einops import rearrange
+
 import sys
 sys.path.append(str(Path(__file__).parent))
-from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+
+from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten, ResidualBlock, PreNorm
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -212,81 +214,81 @@ class RecurrentModule(ShapeMixin, nn.Module):
         tensor = self.rnn(x)
         return tensor
 
-
-class AttentionModule(ShapeMixin, nn.Module):
-    def __init__(self,in_shape, features, dropout=0.1):
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
-        self.in_shape = in_shape
-        self.dropout = dropout
-        self.features = features
-        raise NotImplementedError
-
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
 
     def forward(self, x):
-        pass
+        return self.net(x)
 
-
-class MultiHeadAttentionModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, heads, features, dropout=0.1):
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dropout = 0.):
         super().__init__()
-        self.in_shape = in_shape
-
-        self.features = features
         self.heads = heads
-        self.final_dim = self.features // self.heads
+        self.scale = dim ** -0.5
 
-        self.linear_q = LinearModule(self.features, self.features)
-        self.linear_v = LinearModule(self.features, self.features)
-        self.linear_k = LinearModule(self.features, self.features)
-        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.features)
-        self.linear_out = nn.Linear(self.features, self.features)
+        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Dropout(dropout)
+        )
 
-    def forward(self, q, k, v, mask=None):
+    def forward(self, x, mask = None):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv]
 
-        batch_size = q.size(0)
+        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max
 
-        # perform linear operation and split into h heads
-        k = self.linear_k(k).view(batch_size, -1, self.heads, self.final_dim)
-        q = self.linear_q(q).view(batch_size, -1, self.heads, self.final_dim)
-        v = self.linear_v(v).view(batch_size, -1, self.heads, self.final_dim)
-
-        # transpose to get dimensions bs * h * sl * features
-        # ToDo: Do we need this?
-
-        k = k.transpose(1, 2)
-        q = q.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # calculate attention
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.final_dim)
         if mask is not None:
-            mask = mask.unsqueeze(1)
-            scores = scores.masked_fill(mask == 0, -1e9)
-        scores = F.softmax(scores, dim=-1)
-        scores = self.dropout(scores)
-        scores = torch.matmul(scores, v)
+            mask = F.pad(mask.flatten(1), [1, 0], value = True)
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots.masked_fill_(~mask, mask_value)
+            del mask
 
-        # concatenate heads and apply final linear transformation
-        # ToDo: This seems to be old coding style. Do we Need this?
-        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.features)
+        attn = dots.softmax(dim=-1)
 
-        output = self.out(concat)
-        return output
+        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return out
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                ResidualBlock(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                ResidualBlock(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+            ]))
+
+    def forward(self, x, mask = None, *_, **__):
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask)
+            x = ff(x)
+        return x
 
 
 class TransformerModule(ShapeMixin, nn.Module):
 
-    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, **kwargs):
+    def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, activation='gelu'):
         super(TransformerModule, self).__init__()
 
         self.in_shape = in_shape
 
         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
 
-        encoder_layer = nn.TransformerEncoderLayer(self.flat_shape, n_heads, dim_feedforward=hidden_size,
-                                                   dropout=dropout, activation=kwargs.get('activation')
-                                                   )
-        self.norm = nn.LayerNorm(hidden_size) if use_norm else F_x(hidden_size)
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, )
+        self.transformer = Transformer(dim=self.flat.flat_shape, depth=num_layers, heads=n_heads,
+                                       mlp_dim=hidden_size, dropout=dropout)
 
     def forward(self, x, mask=None, key_padding_mask=None):
         tensor = self.flat(x)
diff --git a/modules/model_parts.py b/modules/model_parts.py
index a837da4..2150e64 100644
--- a/modules/model_parts.py
+++ b/modules/model_parts.py
@@ -11,7 +11,7 @@ from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader
 
-from .blocks import ConvModule, DeConvModule, LinearModule, MultiHeadAttentionModule
+from .blocks import ConvModule, DeConvModule, LinearModule
 from .util import ShapeMixin, LightningBaseModule, Flatten
 
 
@@ -112,6 +112,7 @@ class Generator(ShapeMixin, nn.Module):
 
         last_shape = re_shape
         for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            # noinspection PyTypeChecker
             self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                   conv_kernel=conv_kernel,
                                                   conv_padding=conv_kernel-2,
@@ -275,16 +276,3 @@ class Encoder(BaseEncoder):
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
-
-
-class TransformerEncoder(ShapeMixin, nn.Module):
-
-    def __init__(self, in_shape):
-        super(TransformerEncoder, self).__init__()
-        # MultiheadSelfAttention
-        self.msa = MultiHeadAttentionModule()
-
-
-    def forward(self, x):
-
-
diff --git a/modules/util.py b/modules/util.py
index 020114a..3394aae 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,3 +1,5 @@
+from typing import List
+
 from functools import reduce
 
 from abc import ABC
@@ -6,7 +8,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch.nn import functional as F
+from torch.nn import functional as F, Unfold
 
 
 # Utility - Modules
 ###################
@@ -38,6 +40,7 @@ try:
             ################################
             self.hparams = hparams
             self.params = ModelParameters(hparams)
+            self.lr = self.params.lr or 1e-4
 
         def size(self):
             return self.shape
@@ -76,10 +79,10 @@ try:
             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
             self.apply(weight_initializer)
 
-    modules = [LightningBaseModule, nn.Module]
+    module_types = (LightningBaseModule, nn.Module,)
 
 except ImportError:
-    modules = [nn.Module, ]
+    module_types = (nn.Module,)
     pass  # Maybe post a hint to install pytorch-lightning.
 
 
@@ -88,7 +91,7 @@ class ShapeMixin:
 
     @property
     def shape(self):
-        assert isinstance(self, modules)
+        assert isinstance(self, module_types)
 
         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
@@ -135,6 +138,41 @@ class F_x(ShapeMixin, nn.Module):
         return x
 
 
+class ResidualBlock(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+
+class SlidingWindow(nn.Module):
+    def __init__(self, kernel, stride=1, padding=0, keepdim=False):
+        super(SlidingWindow, self).__init__()
+        self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
+        self.padding = padding
+        self.stride = stride
+        self.keepdim = keepdim
+        self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
+
+    def forward(self, x):
+        tensor = self._unfolder(x)
+        tensor = tensor.transpose(-1, -2)
+        if self.keepdim:
+            shape = *x.shape[:2], -1, *self.kernel
+            tensor = tensor.reshape(shape)
+        return tensor
+
+
 # Utility - Modules
 ###################
 class Flatten(ShapeMixin, nn.Module):
@@ -232,14 +270,13 @@ class AutoPadToShape(object):
 
     def __call__(self, x):
         if not torch.is_tensor(x):
             x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape or x.shape == self.shape:
+        if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
             return x
-        for i in range(-1, -len(self.shape), -1):
-            idx = [0] * len(x.shape)
-            idx[i] = self.shape[i] - x.shape[i]
-            idx = tuple(idx)
-            x = torch.nn.functional.pad(x, idx)
+        idx = [0] * (len(self.shape) * 2)
+        for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
+            idx[j] = self.shape[i] - x.shape[i]
+        x = torch.nn.functional.pad(x, idx)
         return x
 
     def __repr__(self):
diff --git a/utils/config.py b/utils/config.py
index 040c231..700f856 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -94,7 +94,7 @@ class Config(ConfigParser, ABC):
         try:
             return locate_and_import_class(self.model.type)
         except AttributeError as e:
-            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
+            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") ' +
                                  f'was not found!\n' +
                                  f'{e}')
 
diff --git a/utils/model_io.py b/utils/model_io.py
index fe9374f..9bcc18d 100644
--- a/utils/model_io.py
+++ b/utils/model_io.py
@@ -13,6 +13,10 @@ from torch import nn
 # Hyperparamter Object
 class ModelParameters(Namespace, Mapping):
 
+    @property
+    def activation_as_string(self):
+        return self['activation'].lower()
+
     @property
     def module_kwargs(self):
 
@@ -56,6 +60,7 @@ class ModelParameters(Namespace, Mapping):
 
         _activations = dict(
             leaky_relu=nn.LeakyReLU,
+            gelu=nn.GELU,
             elu=nn.ELU,
             relu=nn.ReLU,
             sigmoid=nn.Sigmoid,
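
Reviewer note, not part of the patch: a minimal smoke test for the new einops-based
Transformer block added to modules/blocks.py. This is a sketch that assumes the
repository root is on the Python path and that torch and einops are installed;
the dimensions chosen below are illustrative only.

    import torch
    from modules.blocks import Transformer

    # builds depth-many pre-norm residual pairs of (Attention, FeedForward),
    # as introduced in this patch
    transformer = Transformer(dim=64, depth=2, heads=4, mlp_dim=128, dropout=0.1)

    tokens = torch.randn(8, 16, 64)   # (batch, sequence length, embedding dim)
    out = transformer(tokens)         # no mask given: plain self-attention over all 16 tokens
    assert out.shape == tokens.shape  # shape is preserved by the residual blocks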