Si11ium 2020-10-07 15:21:45 +02:00
parent 5848b528f0
commit f296ba78b9
6 changed files with 78 additions and 39 deletions

View File

@@ -76,9 +76,9 @@ class NormalizeMelband(object):
 class AudioToMel(object):
-    def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
+    def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
         assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
-        self.mel_kwargs = kwargs
+        self.mel_kwargs = mel_kwargs
         self.amplitude_to_db = amplitude_to_db
         self.power_to_db = power_to_db

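The renamed keyword bag makes it explicit that the collected arguments feed the mel-spectrogram computation. A minimal usage sketch, assuming AudioToMel is in scope; n_mels and hop_length are hypothetical librosa-style parameters:

to_mel = AudioToMel(power_to_db=True, n_mels=128, hop_length=256)
assert to_mel.mel_kwargs == dict(n_mels=128, hop_length=256)  # extra kwargs land in mel_kwargs
assert to_mel.power_to_db is True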
View File

@@ -59,9 +59,9 @@ class ShiftTime(object):
             # Set to silence for heading/ tailing
             shift = int(shift)
             if shift > 0:
-                augmented_data[:, :shift] = 0
+                augmented_data[:shift, :] = 0
             else:
-                augmented_data[:, shift:] = 0
+                augmented_data[shift:, :] = 0
             return augmented_data
         else:
             return x

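The fix changes which axis gets silenced after the shift. A toy numpy sketch, assuming the augmented array is laid out as (time, features):

import numpy as np

augmented_data = np.ones((6, 4))
shift = 2
augmented_data[:shift, :] = 0   # silences the first two time steps across all feature bins
# The previous indexing, augmented_data[:, :shift] = 0, would instead have zeroed
# the first two feature bins at every time step.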
View File

@@ -0,0 +1,21 @@
+from typing import Union
+
+import numpy as np
+
+
+class Normalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
+
+
+class DeNormalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level

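The new pair maps dB-scaled spectrograms into [0, 1] and back. A round-trip sketch, assuming Normalize and DeNormalize from the new file are importable; the -100 dB floor is a hypothetical value:

import numpy as np

min_db_level = -100.0
s_db = np.array([[-100.0, -50.0, 0.0]])

s_norm = Normalize(min_db_level)(s_db)      # -> [[0.0, 0.5, 1.0]]
s_back = DeNormalize(min_db_level)(s_norm)  # -> [[-100.0, -50.0, 0.0]]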
View File

@@ -2,14 +2,18 @@
 # Full Model Parts
 ###################
 from argparse import Namespace
+from functools import reduce
 from typing import Union, List, Tuple

 import torch
 from abc import ABC
+from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader

-from .util import ShapeMixin, LightningBaseModule
+from .blocks import ConvModule, DeConvModule, LinearModule
+from .util import ShapeMixin, LightningBaseModule, Flatten


 class AEBaseModule(LightningBaseModule, ABC):
@@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
     def encode(self, x):
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
-        return self.encoder(x).squeeze()
+        return self.encoder(x)

     def _find_min_max(self, dataloader):
         encodings = list()
@@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
         return self.decode(random_latent_samples).cpu().detach()

     def decode(self, z):
-        if len(z.shape) == 1:
-            z = z.unsqueeze(0)
+        try:
+            if len(z.shape) == 1:
+                z = z.unsqueeze(0)
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         return self.decoder(z).squeeze()

     def encode_and_restore(self, x):
-        x = x.to(self.device)
+        x = self.transfer_batch_to_device(x, self.device)
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
         z = self.encode(x)
+        try:
+            z = z.squeeze()
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         x_hat = self.decode(z)
         return Namespace(main_out=x_hat.squeeze(), latent_out=z)


-class Generator(nn.Module):
+class Generator(ShapeMixin, nn.Module):

-    @property
-    def shape(self):
-        x = torch.randn(self.lat_dim).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
-
-    # noinspection PyUnresolvedReferences
-    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
-                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
+    def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
+                 dropout: Union[int, float] = 0, interpolations: List[int] = None,
+                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
+                 **kwargs):
         super(Generator, self).__init__()
         assert filters, '"Filters" has to be a list of int.'
         assert filters, '"Filters" has to be a list of int.'
         assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
-        self.filters = filters
-        self.activation = activation
-        self.inner_activation = activation()
+        interpolations = interpolations or [2, 2, 2]
+
+        self.in_shape = in_shape
+        self.activation = activation()
         self.out_activation = None
-        self.lat_dim = lat_dim
         self.dropout = dropout
-        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
+        self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
         # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
-        self.flat = Flatten(to=re_shape)
+        self.flat = Flatten(self.l1.shape, to=re_shape)
         self.de_conv_list = nn.ModuleList()

         last_shape = re_shape
-        for conv_filter, conv_kernel in zip(filters, kernels):
-            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
+        for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                   conv_kernel=conv_kernel,
                                                   conv_padding=conv_kernel-2,
-                                                  conv_stride=conv_filter,
+                                                  conv_stride=1,
                                                   normalize=use_norm,
-                                                  activation=self.activation,
-                                                  interpolation_scale=2,
+                                                  activation=activation,
+                                                  interpolation_scale=interpolation,
                                                   dropout=self.dropout
                                                   )
                                      )
@@ -121,7 +130,7 @@ class Generator(nn.Module):
     def forward(self, z):
         tensor = self.l1(z)
-        tensor = self.inner_activation(tensor)
+        tensor = self.activation(tensor)
         tensor = self.flat(tensor)

         for de_conv in self.de_conv_list:
@@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
                                            )
                                  )
             last_shape = self.conv_list[-1].shape

-        self.flat = Flatten()
+        self.last_conv_shape = last_shape
+        self.flat = Flatten(self.last_conv_shape)

     def forward(self, x):
         tensor = x
@@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):
 class Encoder(BaseEncoder):

-    # noinspection PyUnresolvedReferences
     def __init__(self, *args, **kwargs):
         super(Encoder, self).__init__(*args, **kwargs)
-        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
+        self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)

     def forward(self, x):
         tensor = super(Encoder, self).forward(x)

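In the rebuilt Generator, up-sampling now comes from a per-block interpolation factor while the convolution stride stays at 1. DeConvModule is project-internal, so this is only a minimal torch sketch of that idea, not the repository module:

import torch
from torch import nn

x = torch.randn(2, 8, 4, 4)
block = nn.Sequential(nn.Upsample(scale_factor=2),                 # interpolation does the up-sampling
                      nn.Conv2d(8, 16, kernel_size=3, padding=1))  # stride-1 convolution keeps the size
print(block(x).shape)   # torch.Size([2, 16, 8, 8])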
View File

@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch import functional as F
+from torch.nn import functional as F

 import pytorch_lightning as pl
@@ -92,7 +92,14 @@ class ShapeMixin:
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

         if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(self.in_shape, device=device)
+
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()

-        self.in_shape = in_shape
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
         self.n = n
         self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)

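ShapeMixin now resolves a device before building its probe tensor instead of always sampling on the CPU. An illustrative standalone helper (not part of the repository) that mirrors the same fallback chain:

import torch
from torch import nn

def resolve_device(module: nn.Module) -> torch.device:
    try:
        return module.device                         # e.g. a LightningModule exposes .device directly
    except AttributeError:
        try:
            return next(module.parameters()).device  # fall back to the first parameter's device
        except StopIteration:
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # parameter-free module

print(resolve_device(nn.Linear(4, 2)))   # cpu: taken from the layer's parameters
print(resolve_device(nn.Sequential()))   # default device for a module without parameters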
View File

@@ -2,7 +2,6 @@ import importlib
 import pickle
 import shelve
 from pathlib import Path, PurePath
-from pydoc import safeimport
 from typing import Union

 import numpy as np
@@ -50,6 +49,8 @@ def locate_and_import_class(class_name, models_location: Union[str, PurePath] =
         mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
         try:
             model_class = mod.__getattribute__(class_name)
+            return model_class
         except AttributeError:
             continue
-    return model_class
+
+    raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
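The lookup now returns as soon as one of the candidate modules provides the class and raises a descriptive error only after all of them have been checked; the old fall-through return failed with an UnboundLocalError when no module defined the class. A standalone sketch of the same pattern (module discovery simplified to a plain list, not the repository function):

import importlib

def find_class(class_name, module_names):
    for module_name in module_names:
        mod = importlib.import_module(module_name)
        try:
            return getattr(mod, class_name)   # first module that provides the class wins
        except AttributeError:
            continue
    raise AttributeError(f'Check the class name. Searched modules:\n{module_names}')

print(find_class('OrderedDict', ['json', 'collections']))  # json lacks it, collections provides it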