From a4b6c698c339172725ecc45120e5ee63844592a4 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Thu, 6 Aug 2020 08:12:07 +0200
Subject: [PATCH] InterSpeech Camera Ready Reporting

---
 modules/blocks.py | 38 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/modules/blocks.py b/modules/blocks.py
index 6b50811..791da14 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -5,6 +5,7 @@ import torch
 import warnings
 
 from torch import nn
+from torch.nn import functional as F
 import sys
 sys.path.append(str(Path(__file__).parent))
 from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
@@ -42,14 +43,21 @@ class LinearModule(ShapeMixin, nn.Module):
 class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
-                 bias=True, norm=False, dropout: Union[int, float] = 0,
+                 bias=True, norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
         assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
         assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        if norm and not trainable:
+            warnings.warn('You set this module to be non-trainable, but the running norm is active.\n' +
+                          'It has been set to "eval" mode.\n' +
+                          'Keep this in mind if you do a fine-tuning or retraining step.'
+                          )
+
         # Module Parameters
         self.in_shape = in_shape
+        self.trainable = trainable
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
 
         # Convolution Parameters
@@ -66,6 +74,13 @@ class ConvModule(ShapeMixin, nn.Module):
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
+        if not self.trainable:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.norm = self.norm.eval()
+        else:
+            pass
+
 
     def forward(self, x):
         tensor = self.norm(x)
@@ -76,6 +91,27 @@ class ConvModule(ShapeMixin, nn.Module):
         return tensor
 
 
+# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
+
+class SobelFilter(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(SobelFilter, self).__init__()
+        self.in_shape = in_shape
+        self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float).view(1, 1, 3, 3)
+        self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float).view(1, 1, 3, 3)
+
+    def forward(self, x):
+        # Apply the filters (expects a single-channel input)
+        g_x = F.conv2d(x, self.sobel_x)
+        g_y = F.conv2d(x, self.sobel_y)
+        # Squared gradient magnitude (edge strength)
+        g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
+        # Gradient direction
+        g_grad = torch.atan2(g_y, g_x)
+        return g_x, g_y, g, g_grad
+
+
 class DeConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
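
Usage note for the new trainable flag: a minimal sketch, assuming that
norm=True creates a running-norm layer as the new warning implies; the
in_shape, filter count, and kernel size below are made up for illustration
and are not part of the patch.

    import torch
    from modules.blocks import ConvModule

    # Hypothetical values, chosen only to satisfy the asserts in __init__.
    frozen = ConvModule(in_shape=(1, 32, 32), conv_filters=16, conv_kernel=3,
                        norm=True, trainable=False)

    # Constructing with norm=True and trainable=False should trigger the new
    # warning; __init__ then disables gradients on every parameter and puts
    # the norm layer into eval mode.
    assert all(not p.requires_grad for p in frozen.parameters())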
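
The SobelFilter module assumes a single-channel input, since both kernels
have a single input channel. A minimal sketch of how it might be called,
with a made-up batch shape:

    import torch
    from modules.blocks import SobelFilter

    sobel = SobelFilter(in_shape=(1, 32, 32))  # hypothetical input shape
    x = torch.randn(4, 1, 32, 32)              # batch of single-channel images
    g_x, g_y, g, g_grad = sobel(x)

    # The un-padded 3x3 convolution shrinks each spatial dim by 2.
    print(g.shape)  # torch.Size([4, 1, 30, 30])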