Hparams passing with user warnings
@@ -3,12 +3,13 @@ import torch
 from scipy.signal import butter, lfilter
 
 from ml_lib.modules.utils import AutoPad
+import numpy as np
 
 
 def butter_lowpass(cutoff, sr, order=5):
     nyq = 0.5 * sr
     normal_cutoff = cutoff / nyq
-    b, a = butter(order, normal_cutoff, btype='low', analog=False)
+    # noinspection PyTupleAssignmentBalance
+    b, a = butter(order, normal_cutoff, btype='low', analog=False, output='ba')
     return b, a
 
 
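For context, a minimal usage sketch of the filter helper above (illustrative only, not part of the commit; the sample rate, cutoff, and dummy signal are assumptions):

    import numpy as np
    from scipy.signal import lfilter

    # butter_lowpass as defined in the hunk above
    signal = np.random.randn(16000)                    # one second of noise at an assumed 16 kHz sample rate
    b, a = butter_lowpass(cutoff=1000, sr=16000, order=5)
    filtered = lfilter(b, a, signal)                   # apply the designed low-pass filter
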
@@ -57,18 +58,19 @@ class NormalizeMelband(object):
         return x
 
 
-class AutoPadTransform(object):
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-        self.padder = AutoPad()
+class AutoPadToShape(object):
+    def __init__(self, shape):
+        self.shape = shape
 
-    def __call__(self, y):
-        if not torch.is_tensor(y):
-            y = torch.as_tensor(y)
-        return self.padder(y)
+    def __call__(self, x):
+        if not torch.is_tensor(x):
+            x = torch.as_tensor(x)
+        embedding = torch.zeros(self.shape)
+        embedding[: x.shape] = x
+        return embedding
 
     def __repr__(self):
-        return 'AutoPadTransform()'
+        return f'AutoPadTransform({self.shape})'
 
 
 class Melspectogram(object):

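As a rough illustration of what the new AutoPadToShape transform is meant to do (zero-pad a tensor up to a fixed target shape), a standalone sketch follows; the target shape, the sample tensor, and the explicit slice-based assignment are assumptions of this sketch, not code from the commit:

    import torch

    target_shape = (128, 400)                          # assumed fixed (n_mels, time) output shape
    x = torch.rand(128, 350)                           # a shorter sample that needs padding
    embedding = torch.zeros(target_shape)
    # copy the input into the top-left corner; the remainder stays zero (the padding)
    embedding[tuple(slice(0, s) for s in x.shape)] = x
    assert embedding.shape == torch.Size(target_shape)
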
@@ -1,4 +1,5 @@
 from typing import Union
+import warnings
 
 import torch
 from torch import nn
@@ -19,9 +20,9 @@ class ConvModule(nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                  use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
-                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
+                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
+        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         # Module Parameters
         self.in_shape = in_shape
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]

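Taken together with the import of warnings above, the **kwargs change lets a full hyper-parameter dict be splatted into the module: unknown keys are collected and only reported via warnings.warn instead of raising a TypeError. A minimal sketch of that calling pattern, using a hypothetical stand-in class (the parameter and hparam names are assumptions):

    import warnings

    class DemoModule:
        # same pattern as ConvModule above: accept leftover hparams and warn about them
        def __init__(self, in_shape, conv_filters, **kwargs):
            warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
            self.in_shape, self.conv_filters = in_shape, conv_filters

    hparams = dict(in_shape=(1, 128, 400), conv_filters=32, learning_rate=1e-3, batch_size=64)
    model = DemoModule(**hparams)   # warns about ['learning_rate', 'batch_size'] rather than failing
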
@@ -11,8 +11,16 @@ import pytorch_lightning as pl
 
 # Utility - Modules
 ###################
+class F_x(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, x):
+        return x
 
 
+# Utility - Modules
+###################
 class Flatten(nn.Module):
 
     @property

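F_x above is a plain identity callable; a likely use (an assumption, it is not shown in this commit) is as a drop-in no-op wherever an optional transform or activation callable is expected:

    import torch

    # F_x as defined in the hunk above
    maybe_norm = F_x()                     # stands in where a normalization callable would otherwise go
    x = torch.rand(3, 4)
    assert torch.equal(maybe_norm(x), x)   # returns its input unchanged
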
@@ -116,7 +116,7 @@ class Config(ConfigParser, ABC):
     def build_model(self):
         return self.model_class(self.model_paramters)
 
-    def build_and_init_model(self, weight_init_function):
+    def build_and_init_model(self, in_shape, weight_init_function):
         model = self.build_model()
         model.init_weights(weight_init_function)
         return model

@@ -37,8 +37,8 @@ class Logger(LightningLoggerBase):
 
     @property
     def outpath(self):
-        # ToDo: Add further path modification such as dataset config etc.
-        return Path(self.config.train.outpath) / self.config.data.mode
+        # FIXME: Move this out of here, this is not the right place to do this!!!
+        return Path(self.config.train.outpath) / self.config.model.type
 
     def __init__(self, config: Config):
         """