From f296ba78b9231945fc2fa8814bd8b23ee56210f5 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Wed, 7 Oct 2020 15:21:45 +0200
Subject: [PATCH] A Lot

---
 audio_toolset/audio_io.py         |  4 +-
 audio_toolset/mel_augmentation.py |  4 +-
 audio_toolset/mel_transforms.py   | 21 ++++++++++
 modules/model_parts.py            | 68 ++++++++++++++++++-------------
 modules/util.py                   | 13 ++++--
 utils/tools.py                    |  7 ++--
 6 files changed, 78 insertions(+), 39 deletions(-)
 create mode 100644 audio_toolset/mel_transforms.py

diff --git a/audio_toolset/audio_io.py b/audio_toolset/audio_io.py
index bc32679..e0e02a2 100644
--- a/audio_toolset/audio_io.py
+++ b/audio_toolset/audio_io.py
@@ -76,9 +76,9 @@ class NormalizeMelband(object):
 
 
 class AudioToMel(object):
-    def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
+    def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
         assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
-        self.mel_kwargs = kwargs
+        self.mel_kwargs = mel_kwargs
         self.amplitude_to_db = amplitude_to_db
         self.power_to_db = power_to_db
diff --git a/audio_toolset/mel_augmentation.py b/audio_toolset/mel_augmentation.py
index ec1095f..acb356b 100644
--- a/audio_toolset/mel_augmentation.py
+++ b/audio_toolset/mel_augmentation.py
@@ -59,9 +59,9 @@ class ShiftTime(object):
             # Set to silence for heading/ tailing
             shift = int(shift)
             if shift > 0:
-                augmented_data[:, :shift] = 0
+                augmented_data[:shift, :] = 0
             else:
-                augmented_data[:, shift:] = 0
+                augmented_data[shift:, :] = 0
             return augmented_data
         else:
             return x
diff --git a/audio_toolset/mel_transforms.py b/audio_toolset/mel_transforms.py
new file mode 100644
index 0000000..9aa470a
--- /dev/null
+++ b/audio_toolset/mel_transforms.py
@@ -0,0 +1,21 @@
+from typing import Union
+
+import numpy as np
+
+
+class Normalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
+
+
+class DeNormalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level
diff --git a/modules/model_parts.py b/modules/model_parts.py
index 9afd028..a89dbc4 100644
--- a/modules/model_parts.py
+++ b/modules/model_parts.py
@@ -2,14 +2,18 @@
 # Full Model Parts
 ###################
 from argparse import Namespace
+from functools import reduce
 from typing import Union, List, Tuple
 
 import torch
 from abc import ABC
+from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader
 
-from .util import ShapeMixin, LightningBaseModule
+from .blocks import ConvModule, DeConvModule, LinearModule
+
+from .util import ShapeMixin, LightningBaseModule, Flatten
 
 
 class AEBaseModule(LightningBaseModule, ABC):
@@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
     def encode(self, x):
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
-        return self.encoder(x).squeeze()
+        return self.encoder(x)
 
     def _find_min_max(self, dataloader):
         encodings = list()
@@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
         return self.decode(random_latent_samples).cpu().detach()
 
     def decode(self, z):
-        if len(z.shape) == 1:
-            z = z.unsqueeze(0)
+        try:
+            if len(z.shape) == 1:
+                z = z.unsqueeze(0)
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         return self.decoder(z).squeeze()
 
     def encode_and_restore(self, x):
-        x = x.to(self.device)
+        x = self.transfer_batch_to_device(x, self.device)
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
         z = self.encode(x)
+        try:
+            z = z.squeeze()
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         x_hat = self.decode(z)
         return Namespace(main_out=x_hat.squeeze(), latent_out=z)
 
 
-class Generator(nn.Module):
-    @property
-    def shape(self):
-        x = torch.randn(self.lat_dim).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
+class Generator(ShapeMixin, nn.Module):
 
-    # noinspection PyUnresolvedReferences
-    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
-                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
+    def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
+                 dropout: Union[int, float] = 0, interpolations: List[int] = None,
+                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
+                 **kwargs):
         super(Generator, self).__init__()
         assert filters, '"Filters" has to be a list of int.'
         assert filters, '"Filters" has to be a list of int.'
         assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
-        self.filters = filters
-        self.activation = activation
-        self.inner_activation = activation()
+
+        interpolations = interpolations or [2, 2, 2]
+
+        self.in_shape = in_shape
+        self.activation = activation()
         self.out_activation = None
-        self.lat_dim = lat_dim
         self.dropout = dropout
 
-        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
+        self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
         # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
 
-        self.flat = Flatten(to=re_shape)
+        self.flat = Flatten(self.l1.shape, to=re_shape)
         self.de_conv_list = nn.ModuleList()
 
         last_shape = re_shape
-        for conv_filter, conv_kernel in zip(filters, kernels):
-            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
+        for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                   conv_kernel=conv_kernel,
                                                   conv_padding=conv_kernel-2,
-                                                  conv_stride=conv_filter,
+                                                  conv_stride=1,
                                                   normalize=use_norm,
-                                                  activation=self.activation,
-                                                  interpolation_scale=2,
+                                                  activation=activation,
+                                                  interpolation_scale=interpolation,
                                                   dropout=self.dropout
                                                   )
                                      )
@@ -121,7 +130,7 @@ def forward(self, z):
         tensor = self.l1(z)
-        tensor = self.inner_activation(tensor)
+        tensor = self.activation(tensor)
         tensor = self.flat(tensor)
 
         for de_conv in self.de_conv_list:
@@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
                                             )
                                  )
             last_shape = self.conv_list[-1].shape
+        self.last_conv_shape = last_shape
 
-        self.flat = Flatten()
+        self.flat = Flatten(self.last_conv_shape)
 
     def forward(self, x):
         tensor = x
@@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):
 
 class Encoder(BaseEncoder):
 
-    # noinspection PyUnresolvedReferences
+
     def __init__(self, *args, **kwargs):
         super(Encoder, self).__init__(*args, **kwargs)
 
-        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
+        self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
 
     def forward(self, x):
         tensor = super(Encoder, self).forward(x)
diff --git a/modules/util.py b/modules/util.py
index 87548f7..3eb96ba 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -6,7 +6,7 @@ from pathlib import Path
 
 import torch
 from operator import mul
 from torch import nn
-from torch import functional as F
+from torch.nn import functional as F
 
 import pytorch_lightning as pl
@@ -92,7 +92,14 @@ class ShapeMixin:
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
 
         if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(self.in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
 
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()
-        self.in_shape = in_shape
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
         self.n = n
         self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
diff --git a/utils/tools.py b/utils/tools.py
index 6d2395f..f93663e 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -2,7 +2,6 @@ import importlib
 import pickle
 import shelve
 from pathlib import Path, PurePath
-from pydoc import safeimport
 from typing import Union
 
 import numpy as np
@@ -50,6 +49,8 @@ def locate_and_import_class(class_name, models_location: Union[str, PurePath] =
         mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
         try:
             model_class = mod.__getattribute__(class_name)
+            return model_class
         except AttributeError:
-            continue
-    return model_class
+            continue
+    raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
+
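
Note (commentary, not part of the applied diff): a minimal round-trip sketch
for the new Normalize/DeNormalize transforms in audio_toolset/mel_transforms.py.
The min_db_level of -100 dB is an assumed example value; the patch does not
prescribe one.

    import numpy as np

    from audio_toolset.mel_transforms import Normalize, DeNormalize

    min_db_level = -100  # assumed dB floor; use your pipeline's actual value
    normalize = Normalize(min_db_level)
    denormalize = DeNormalize(min_db_level)

    spec = np.random.uniform(min_db_level, 0, size=(80, 128))  # stand-in dB-scaled mel spectrogram
    scaled = normalize(spec)                        # mapped into [0, 1]
    assert scaled.min() >= 0 and scaled.max() <= 1
    assert np.allclose(spec, denormalize(scaled))   # exact for values within [min_db_level, 0]

Two review asides: the duplicated `assert filters` context line in
Generator.__init__ looks like an accidental duplicate already present in the
source file, and the utils/tools.py hunk changes the failure mode of
locate_and_import_class(): it now returns the first matching class and raises
AttributeError listing the candidate model files when nothing matches, rather
than leaving model_class unbound after the loop.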