From f5c240f03887d2354da7e503dd47aa7ef81412de Mon Sep 17 00:00:00 2001 From: steffen Date: Fri, 17 Apr 2020 18:07:54 +0200 Subject: [PATCH] Hparams passing with user warnings --- audio_toolset/audio_io.py | 24 +++++++++++++----------- modules/blocks.py | 5 +++-- modules/utils.py | 8 ++++++++ utils/config.py | 2 +- utils/logging.py | 4 ++-- 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/audio_toolset/audio_io.py b/audio_toolset/audio_io.py index 839fafd..440194c 100644 --- a/audio_toolset/audio_io.py +++ b/audio_toolset/audio_io.py @@ -3,12 +3,13 @@ import torch from scipy.signal import butter, lfilter from ml_lib.modules.utils import AutoPad - +import numpy as np def butter_lowpass(cutoff, sr, order=5): nyq = 0.5 * sr normal_cutoff = cutoff / nyq - b, a = butter(order, normal_cutoff, btype='low', analog=False) + # noinspection PyTupleAssignmentBalance + b, a = butter(order, normal_cutoff, btype='low', analog=False, output='ba') return b, a @@ -57,18 +58,19 @@ class NormalizeMelband(object): return x -class AutoPadTransform(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - self.padder = AutoPad() +class AutoPadToShape(object): + def __init__(self, shape): + self.shape = shape - def __call__(self, y): - if not torch.is_tensor(y): - y = torch.as_tensor(y) - return self.padder(y) + def __call__(self, x): + if not torch.is_tensor(x): + x = torch.as_tensor(x) + embedding = torch.zeros(self.shape) + embedding[: x.shape] = x + return embedding def __repr__(self): - return 'AutoPadTransform()' + return f'AutoPadTransform({self.shape})' class Melspectogram(object): diff --git a/modules/blocks.py b/modules/blocks.py index 316f528..ac103ac 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -1,4 +1,5 @@ from typing import Union +import warnings import torch from torch import nn @@ -19,9 +20,9 @@ class ConvModule(nn.Module): def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False, dropout: Union[int, float] = 0, - conv_class=nn.Conv2d, conv_stride=1, conv_padding=0): + conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs): super(ConvModule, self).__init__() - + warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}') # Module Parameters self.in_shape = in_shape in_channels, height, width = in_shape[0], in_shape[1], in_shape[2] diff --git a/modules/utils.py b/modules/utils.py index 6fe0338..c307b9d 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -11,8 +11,16 @@ import pytorch_lightning as pl # Utility - Modules ################### +class F_x(object): + def __init__(self): + pass + + def __call__(self, x): + return x +# Utility - Modules +################### class Flatten(nn.Module): @property diff --git a/utils/config.py b/utils/config.py index d67bccd..a6989bb 100644 --- a/utils/config.py +++ b/utils/config.py @@ -116,7 +116,7 @@ class Config(ConfigParser, ABC): def build_model(self): return self.model_class(self.model_paramters) - def build_and_init_model(self, weight_init_function): + def build_and_init_model(self, in_shape, weight_init_function): model = self.build_model() model.init_weights(weight_init_function) return model diff --git a/utils/logging.py b/utils/logging.py index 74f233d..9597596 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -37,8 +37,8 @@ class Logger(LightningLoggerBase): @property def outpath(self): - # ToDo: Add further path modification such as dataset config etc. - return Path(self.config.train.outpath) / self.config.data.mode + # FIXME: Move this out of here, this is not the right place to do this!!! + return Path(self.config.train.outpath) / self.config.model.type def __init__(self, config: Config): """