Hparams passing with user warnings

This commit is contained in:
steffen
2020-04-17 18:07:54 +02:00
parent e53107420d
commit f5c240f038
5 changed files with 27 additions and 16 deletions

View File

@@ -3,12 +3,13 @@ import torch
from scipy.signal import butter, lfilter from scipy.signal import butter, lfilter
from ml_lib.modules.utils import AutoPad from ml_lib.modules.utils import AutoPad
import numpy as np
def butter_lowpass(cutoff, sr, order=5): def butter_lowpass(cutoff, sr, order=5):
nyq = 0.5 * sr nyq = 0.5 * sr
normal_cutoff = cutoff / nyq normal_cutoff = cutoff / nyq
b, a = butter(order, normal_cutoff, btype='low', analog=False) # noinspection PyTupleAssignmentBalance
b, a = butter(order, normal_cutoff, btype='low', analog=False, output='ba')
return b, a return b, a
@@ -57,18 +58,19 @@ class NormalizeMelband(object):
return x return x
class AutoPadTransform(object): class AutoPadToShape(object):
def __init__(self, **kwargs): def __init__(self, shape):
self.__dict__.update(kwargs) self.shape = shape
self.padder = AutoPad()
def __call__(self, y): def __call__(self, x):
if not torch.is_tensor(y): if not torch.is_tensor(x):
y = torch.as_tensor(y) x = torch.as_tensor(x)
return self.padder(y) embedding = torch.zeros(self.shape)
embedding[: x.shape] = x
return embedding
def __repr__(self): def __repr__(self):
return 'AutoPadTransform()' return f'AutoPadTransform({self.shape})'
class Melspectogram(object): class Melspectogram(object):

View File

@@ -1,4 +1,5 @@
from typing import Union from typing import Union
import warnings
import torch import torch
from torch import nn from torch import nn
@@ -19,9 +20,9 @@ class ConvModule(nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None, def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
use_bias=True, use_norm=False, dropout: Union[int, float] = 0, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0): conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
super(ConvModule, self).__init__() super(ConvModule, self).__init__()
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
# Module Parameters # Module Parameters
self.in_shape = in_shape self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2] in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]

View File

@@ -11,8 +11,16 @@ import pytorch_lightning as pl
# Utility - Modules # Utility - Modules
################### ###################
class F_x(object):
def __init__(self):
pass
def __call__(self, x):
return x
# Utility - Modules
###################
class Flatten(nn.Module): class Flatten(nn.Module):
@property @property

View File

@@ -116,7 +116,7 @@ class Config(ConfigParser, ABC):
def build_model(self): def build_model(self):
return self.model_class(self.model_paramters) return self.model_class(self.model_paramters)
def build_and_init_model(self, weight_init_function): def build_and_init_model(self, in_shape, weight_init_function):
model = self.build_model() model = self.build_model()
model.init_weights(weight_init_function) model.init_weights(weight_init_function)
return model return model

View File

@@ -37,8 +37,8 @@ class Logger(LightningLoggerBase):
@property @property
def outpath(self): def outpath(self):
# ToDo: Add further path modification such as dataset config etc. # FIXME: Move this out of here, this is not the right place to do this!!!
return Path(self.config.train.outpath) / self.config.data.mode return Path(self.config.train.outpath) / self.config.model.type
def __init__(self, config: Config): def __init__(self, config: Config):
""" """