Al Lot

commit f296ba78b9 (parent 5848b528f0)
@@ -76,9 +76,9 @@ class NormalizeMelband(object):
 
 
 class AudioToMel(object):
-    def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
+    def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
         assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
-        self.mel_kwargs = kwargs
+        self.mel_kwargs = mel_kwargs
         self.amplitude_to_db = amplitude_to_db
         self.power_to_db = power_to_db
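For context, a minimal usage sketch of the renamed keyword collection; only the constructor arguments above come from the diff, the mel parameters below are stand-ins for whatever the transform forwards to its mel-spectrogram routine:

    # Hypothetical call site, illustrative keyword values.
    to_mel = AudioToMel(amplitude_to_db=True, n_mels=80, hop_length=256)
    print(to_mel.mel_kwargs)  # {'n_mels': 80, 'hop_length': 256}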
@@ -59,9 +59,9 @@ class ShiftTime(object):
             # Set to silence for heading/ tailing
             shift = int(shift)
             if shift > 0:
-                augmented_data[:, :shift] = 0
+                augmented_data[:shift, :] = 0
             else:
-                augmented_data[:, shift:] = 0
+                augmented_data[shift:, :] = 0
             return augmented_data
         else:
             return x
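The slicing change moves the silenced region from the last axis to the first, i.e. the augmented array is now treated as (time, features) rather than (features, time). A small sketch with made-up shapes:

    import numpy as np

    augmented_data = np.ones((100, 80))   # (time_steps, mel_bands), illustrative shape
    shift = 10
    augmented_data[:shift, :] = 0         # silence the first `shift` time steps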
audio_toolset/mel_transforms.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from typing import Union
+
+import numpy as np
+
+
+class Normalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
+
+
+class DeNormalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level
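A quick round-trip sketch of the two new transforms, assuming the package root is importable; the -100 dB floor and the spectrogram shape are example values:

    import numpy as np
    from audio_toolset.mel_transforms import Normalize, DeNormalize

    min_db = -100                                   # example dB floor
    spec_db = np.random.uniform(-100, 0, (80, 50))  # fake dB-scaled mel spectrogram

    s01 = Normalize(min_db)(spec_db)       # mapped into [0, 1]
    restored = DeNormalize(min_db)(s01)    # back to the original dB range
    assert np.allclose(restored, spec_db)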
@@ -2,14 +2,18 @@
 # Full Model Parts
 ###################
 from argparse import Namespace
+from functools import reduce
 from typing import Union, List, Tuple
 
 import torch
 from abc import ABC
+from operator import mul
 from torch import nn
 from torch.utils.data import DataLoader
 
-from .util import ShapeMixin, LightningBaseModule
+from .blocks import ConvModule, DeConvModule, LinearModule
+
+from .util import ShapeMixin, LightningBaseModule, Flatten
 
 
 class AEBaseModule(LightningBaseModule, ABC):
@@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
     def encode(self, x):
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
-        return self.encoder(x).squeeze()
+        return self.encoder(x)
 
     def _find_min_max(self, dataloader):
         encodings = list()
@@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
         return self.decode(random_latent_samples).cpu().detach()
 
     def decode(self, z):
-        if len(z.shape) == 1:
-            z = z.unsqueeze(0)
+        try:
+            if len(z.shape) == 1:
+                z = z.unsqueeze(0)
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         return self.decoder(z).squeeze()
 
     def encode_and_restore(self, x):
-        x = x.to(self.device)
+        x = self.transfer_batch_to_device(x, self.device)
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
         z = self.encode(x)
+        try:
+            z = z.squeeze()
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
         x_hat = self.decode(z)
 
         return Namespace(main_out=x_hat.squeeze(), latent_out=z)
 
 
-class Generator(nn.Module):
-    @property
-    def shape(self):
-        x = torch.randn(self.lat_dim).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
-
-    # noinspection PyUnresolvedReferences
-    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
-                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
+class Generator(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
+                 dropout: Union[int, float] = 0, interpolations: List[int] = None,
+                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
+                 **kwargs):
         super(Generator, self).__init__()
         assert filters, '"Filters" has to be a list of int.'
         assert filters, '"Filters" has to be a list of int.'
         assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
 
-        self.filters = filters
-        self.activation = activation
-        self.inner_activation = activation()
+        interpolations = interpolations or [2, 2, 2]
+
+        self.in_shape = in_shape
+        self.activation = activation()
         self.out_activation = None
-        self.lat_dim = lat_dim
         self.dropout = dropout
-        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
+        self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
         # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
 
-        self.flat = Flatten(to=re_shape)
+        self.flat = Flatten(self.l1.shape, to=re_shape)
         self.de_conv_list = nn.ModuleList()
 
         last_shape = re_shape
-        for conv_filter, conv_kernel in zip(filters, kernels):
-            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
+        for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                   conv_kernel=conv_kernel,
                                                   conv_padding=conv_kernel-2,
-                                                  conv_stride=conv_filter,
+                                                  conv_stride=1,
                                                   normalize=use_norm,
-                                                  activation=self.activation,
-                                                  interpolation_scale=2,
+                                                  activation=activation,
+                                                  interpolation_scale=interpolation,
                                                   dropout=self.dropout
                                                   )
                                      )
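For orientation, a sketch of how the reworked constructor might be called; every value below is invented for illustration, and reading in_shape as the replacement for the old lat_dim argument is my interpretation of the change, not something stated in the commit:

    # Hypothetical instantiation of the new signature.
    gen = Generator(in_shape=128,            # latent width fed to the first LinearModule
                    out_channels=1,
                    re_shape=(32, 8, 8),     # shape the flat linear output is folded back into
                    filters=[32, 16, 8],
                    kernels=[3, 3, 3],
                    interpolations=[2, 2, 2],
                    use_norm=True,
                    dropout=0.2)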
@@ -121,7 +130,7 @@ class Generator(nn.Module):
 
     def forward(self, z):
         tensor = self.l1(z)
-        tensor = self.inner_activation(tensor)
+        tensor = self.activation(tensor)
         tensor = self.flat(tensor)
 
         for de_conv in self.de_conv_list:
@@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
                                               )
                                     )
             last_shape = self.conv_list[-1].shape
+            self.last_conv_shape = last_shape
 
-        self.flat = Flatten()
+        self.flat = Flatten(self.last_conv_shape)
 
     def forward(self, x):
         tensor = x
@@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):
 
 
 class Encoder(BaseEncoder):
-    # noinspection PyUnresolvedReferences
+
     def __init__(self, *args, **kwargs):
         super(Encoder, self).__init__(*args, **kwargs)
 
-        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
+        self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
 
     def forward(self, x):
         tensor = super(Encoder, self).forward(x)
@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch import functional as F
+from torch.nn import functional as F
 
 import pytorch_lightning as pl
 
@@ -92,7 +92,14 @@ class ShapeMixin:
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
 
         if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(self.in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
 
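The same fallback chain, pulled out as a standalone helper purely for illustration; the function name is mine and not part of the commit:

    import torch
    from torch import nn

    def resolve_device(module: nn.Module) -> torch.device:
        # Prefer an explicit .device attribute (e.g. a LightningModule), then the
        # device of the first parameter, then a plain CUDA/CPU default.
        try:
            return module.device
        except AttributeError:
            try:
                return next(module.parameters()).device
            except StopIteration:
                return torch.device('cuda' if torch.cuda.is_available() else 'cpu')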
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()
 
-        self.in_shape = in_shape
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
         self.n = n
         self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
 
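With the guard in place, Splitter also accepts a plain int for in_shape; a tiny illustration with made-up values:

    splitter = Splitter(in_shape=12, n=3)   # previously required a tuple such as (12,)
    print(splitter.in_shape)                # (12,)
    print(splitter.dim)                     # dim=-1 resolves to len((12,)) - 1 == 0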
@@ -2,7 +2,6 @@ import importlib
 import pickle
 import shelve
 from pathlib import Path, PurePath
-from pydoc import safeimport
 from typing import Union
 
 import numpy as np
@@ -50,6 +49,8 @@ def locate_and_import_class(class_name, models_location: Union[str, PurePath] =
         mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
         try:
             model_class = mod.__getattribute__(class_name)
+            return model_class
         except AttributeError:
             continue
-    return model_class
+    raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
 
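The lookup now returns from the first module that defines the requested class and raises if none does; a rough sketch of the resulting call-site behaviour, with the class names chosen only as examples:

    model_cls = locate_and_import_class('Encoder')   # class object from the first matching model file
    try:
        locate_and_import_class('DoesNotExist')
    except AttributeError as err:
        print(err)   # lists the candidate model files that were searched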