This commit is contained in:
Si11ium 2020-10-07 15:21:45 +02:00
parent 5848b528f0
commit f296ba78b9
6 changed files with 78 additions and 39 deletions

View File

@@ -76,9 +76,9 @@ class NormalizeMelband(object):
class AudioToMel(object):
def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
self.mel_kwargs = kwargs
self.mel_kwargs = mel_kwargs
self.amplitude_to_db = amplitude_to_db
self.power_to_db = power_to_db
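
The constructor only stores the renamed **mel_kwargs; the __call__ that consumes them is outside this hunk. A minimal sketch of how such a transform is commonly wired, assuming librosa as the backend (the stand-in class name and the sr/n_mels values in the usage comment are illustrative, not taken from the repository):

import librosa
import numpy as np

class AudioToMelSketch:
    """Hypothetical stand-in for AudioToMel; only __init__ appears in the hunk above."""

    def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
        assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
        self.mel_kwargs = mel_kwargs          # e.g. sr, n_fft, hop_length, n_mels
        self.amplitude_to_db = amplitude_to_db
        self.power_to_db = power_to_db

    def __call__(self, y: np.ndarray) -> np.ndarray:
        mel = librosa.feature.melspectrogram(y=y, **self.mel_kwargs)   # kwargs forwarded here
        if self.amplitude_to_db:
            mel = librosa.amplitude_to_db(mel)
        elif self.power_to_db:
            mel = librosa.power_to_db(mel)
        return mel

# transform = AudioToMelSketch(power_to_db=True, sr=16000, n_mels=64)
# mel = transform(np.random.randn(16000).astype(np.float32))          # shape (64, frames)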

View File

@@ -59,9 +59,9 @@ class ShiftTime(object):
# Set to silence for heading/ tailing
shift = int(shift)
if shift > 0:
augmented_data[:, :shift] = 0
augmented_data[:shift, :] = 0
else:
augmented_data[:, shift:] = 0
augmented_data[shift:, :] = 0
return augmented_data
else:
return x
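
The indexing change silences the shifted region along the first axis rather than the second, i.e. the time axis presumably comes first in the augmented array. A self-contained sketch of the roll-and-silence pattern under that assumption (the np.roll call and the (time, channels) layout are assumptions; only the zeroing lines appear in this hunk):

import numpy as np

def shift_time_sketch(x: np.ndarray, shift: int) -> np.ndarray:
    # Roll along the first (time) axis, then silence the wrapped-around samples,
    # matching the corrected indexing above.
    shift = int(shift)
    augmented_data = np.roll(x, shift, axis=0)
    if shift > 0:
        augmented_data[:shift, :] = 0    # silence the new heading samples
    else:
        augmented_data[shift:, :] = 0    # silence the new tailing samples
    return augmented_data

# shift_time_sketch(np.ones((8, 2)), 3)[:3].sum()   # 0.0 -- first three time steps are silent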

View File

@@ -0,0 +1,21 @@
from typing import Union
import numpy as np
class Normalize(object):
def __init__(self, min_db_level: Union[int, float]):
self.min_db_level = min_db_level
def __call__(self, s: np.ndarray) -> np.ndarray:
return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
class DeNormalize(object):
def __init__(self, min_db_level: Union[int, float]):
self.min_db_level = min_db_level
def __call__(self, s: np.ndarray) -> np.ndarray:
return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level
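
Normalize maps dB values in [min_db_level, 0] onto [0, 1] and DeNormalize inverts that mapping; a quick round-trip check with the two classes above in scope (min_db_level = -100 is an illustrative value):

import numpy as np

norm, denorm = Normalize(min_db_level=-100), DeNormalize(min_db_level=-100)
s = np.array([-100.0, -50.0, 0.0])           # dB values from min_db_level up to 0 dB
assert np.allclose(norm(s), [0.0, 0.5, 1.0])
assert np.allclose(denorm(norm(s)), s)       # exact round trip for in-range values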

View File

@@ -2,14 +2,18 @@
# Full Model Parts
###################
from argparse import Namespace
from functools import reduce
from typing import Union, List, Tuple
import torch
from abc import ABC
from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from .util import ShapeMixin, LightningBaseModule
from .blocks import ConvModule, DeConvModule, LinearModule
from .util import ShapeMixin, LightningBaseModule, Flatten
class AEBaseModule(LightningBaseModule, ABC):
@@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
def encode(self, x):
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.encoder(x).squeeze()
return self.encoder(x)
def _find_min_max(self, dataloader):
encodings = list()
@@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
return self.decode(random_latent_samples).cpu().detach()
def decode(self, z):
try:
if len(z.shape) == 1:
z = z.unsqueeze(0)
except AttributeError:
# Does not seem to be a tensor.
pass
return self.decoder(z).squeeze()
def encode_and_restore(self, x):
x = x.to(self.device)
x = self.transfer_batch_to_device(x, self.device)
if len(x.shape) == 3:
x = x.unsqueeze(0)
z = self.encode(x)
try:
z = z.squeeze()
except AttributeError:
# Does not seem to be a tensor.
pass
x_hat = self.decode(z)
return Namespace(main_out=x_hat.squeeze(), latent_out=z)
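
encode() above now returns the encoder output without squeeze(), presumably because squeeze() on a batch of one would also drop the batch dimension; encode_and_restore squeezes defensively instead and tolerates latents that are not plain tensors. The try/except pattern in isolation (the Namespace stand-in for a non-tensor latent is illustrative):

from argparse import Namespace
import torch

def maybe_squeeze(z):
    try:
        return z.squeeze()
    except AttributeError:
        # Does not seem to be a tensor: pass it through unchanged.
        return z

print(maybe_squeeze(torch.randn(1, 16)).shape)        # torch.Size([16])
print(maybe_squeeze(Namespace(mu=0.0, logvar=0.0)))   # Namespace(logvar=0.0, mu=0.0), untouched
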
class Generator(nn.Module):
@property
def shape(self):
x = torch.randn(self.lat_dim).unsqueeze(0)
output = self(x)
return output.shape[1:]
class Generator(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
dropout: Union[int, float] = 0, interpolations: List[int] = None,
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
**kwargs):
super(Generator, self).__init__()
assert filters, '"Filters" has to be a list of int.'
assert filters, '"Filters" has to be a list of int.'
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
interpolations = interpolations or [2, 2, 2]
self.in_shape = in_shape
self.activation = activation()
self.out_activation = None
self.lat_dim = lat_dim
self.dropout = dropout
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.flat = Flatten(self.l1.shape, to=re_shape)
self.de_conv_list = nn.ModuleList()
last_shape = re_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=conv_filter,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
activation=activation,
interpolation_scale=interpolation,
dropout=self.dropout
)
)
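
Upsampling is now driven by a per-layer interpolation_scale (with conv_stride fixed to 1) and the filter list is traversed in reverse, mirroring an encoder's channel progression. A quick size check under assumed values (re_shape and interpolations below are illustrative, and the small kernel/padding offset of each convolution is ignored):

re_shape = (128, 4, 4)            # (channels, height, width) after self.flat -- illustrative
interpolations = [2, 2, 2]
spatial = re_shape[1:]
for scale in interpolations:
    spatial = tuple(side * scale for side in spatial)   # each DeConvModule upsamples by its scale
print(spatial)                    # (32, 32) -- grown from (4, 4) by the three interpolation steps
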
@@ -121,7 +130,7 @@ class Generator(nn.Module):
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.activation(tensor)
tensor = self.flat(tensor)
for de_conv in self.de_conv_list:
@@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
)
)
last_shape = self.conv_list[-1].shape
self.last_conv_shape = last_shape
self.flat = Flatten()
self.flat = Flatten(self.last_conv_shape)
def forward(self, x):
tensor = x
@@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):
class Encoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)
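
Several of the changes above rely on Flatten from .util taking an input shape, an optional reshape target `to`, and exposing a .shape property (which Encoder.l1 now uses instead of the old conv3-based computation). That class is not part of this diff; a hypothetical reconstruction consistent with how it is called here:

from functools import reduce
from operator import mul
import torch
from torch import nn

class FlattenSketch(nn.Module):
    """Hypothetical reconstruction of the Flatten module used above: it records the
    incoming shape and either flattens to a vector or reshapes to a fixed target `to`."""

    def __init__(self, in_shape, to=-1):
        super().__init__()
        self.in_shape = in_shape
        self.to = to

    @property
    def shape(self):
        # What Encoder.l1 relies on: the flattened (or reshaped) output shape.
        return reduce(mul, self.in_shape) if self.to == -1 else self.to

    def forward(self, x):
        return x.flatten(1) if self.to == -1 else x.view(x.shape[0], *self.to)

# FlattenSketch((8, 4, 4)).shape == 128
# FlattenSketch(128, to=(8, 4, 4))(torch.randn(2, 128)).shape == (2, 8, 4, 4)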

View File

@@ -6,7 +6,7 @@ from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch import functional as F
from torch.nn import functional as F
import pytorch_lightning as pl
@@ -92,7 +92,14 @@ class ShapeMixin:
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
x = torch.randn(self.in_shape)
try:
device = self.device
except AttributeError:
try:
device = next(self.parameters()).device
except StopIteration:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(self.in_shape, device=device)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
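
The added fallback chain resolves a device even for modules that are not LightningModules or that have no parameters yet; the torch.stack((x, x)) below it exists because BatchNorm layers refuse a single value per channel in training mode. A minimal reproduction of that BatchNorm constraint (the layer and shape are illustrative):

import torch
from torch import nn

bn = nn.BatchNorm1d(16)
x = torch.randn(16)                        # one sample, as produced by torch.randn(self.in_shape)
print(bn(torch.stack((x, x))).shape)       # torch.Size([2, 16]) -- a batch of two passes
try:
    bn(x.unsqueeze(0))                     # a batch of one in training mode
except ValueError as err:
    print(err)                             # "Expected more than 1 value per channel when training ..."
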
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
self.in_shape = in_shape
self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
self.n = n
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
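
Wrapping a bare int into a 1-tuple keeps len(self.in_shape) meaningful for the dim conversion on the next line. A quick check with illustrative values:

in_shape, dim = 20, -1
in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
dim = dim if dim > 0 else len(in_shape) - abs(dim)
print(in_shape, dim)                   # (20,) 0 -- the negative dim now points at the only axis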

View File

@@ -2,7 +2,6 @@ import importlib
import pickle
import shelve
from pathlib import Path, PurePath
from pydoc import safeimport
from typing import Union
import numpy as np
@@ -50,6 +49,8 @@ def locate_and_import_class(class_name, models_location: Union[str, PurePath] =
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
try:
model_class = mod.__getattribute__(class_name)
return model_class
except AttributeError:
continue
return model_class
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
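
Returning inside the try lets the loop move on to the next module file when class_name is missing there, so the trailing raise only fires when no file matched. The same control flow in isolation (the module list and class name are illustrative, not from the repository):

import importlib

def find_class_sketch(class_name, module_names):
    # Try each candidate module; getattr raises AttributeError when the class is absent there.
    for name in module_names:
        mod = importlib.import_module(name)
        try:
            return getattr(mod, class_name)
        except AttributeError:
            continue
    raise AttributeError(f'{class_name} not found in any of {module_names}')

print(find_class_sketch('OrderedDict', ['math', 'collections']))   # <class 'collections.OrderedDict'>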