InterSpeech Camera Ready Reporting

This commit is contained in:
Si11ium 2020-08-06 08:12:07 +02:00
parent 4b089729b2
commit a4b6c698c3

View File

@ -5,6 +5,7 @@ import torch
import warnings import warnings
from torch import nn from torch import nn
from torch.nn import functional as F
import sys import sys
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
@ -42,14 +43,21 @@ class LinearModule(ShapeMixin, nn.Module):
class ConvModule(ShapeMixin, nn.Module): class ConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None, def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
bias=True, norm=False, dropout: Union[int, float] = 0, bias=True, norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs): conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
super(ConvModule, self).__init__() super(ConvModule, self).__init__()
assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}' assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}' assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}') warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
if norm and not trainable:
warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
'We set it to "eval" mode.\n' +
'Keep this in mind if you do a finetunning or retraining step.'
)
# Module Parameters # Module Parameters
self.in_shape = in_shape self.in_shape = in_shape
self.trainable = trainable
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2] in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
# Convolution Parameters # Convolution Parameters
@ -66,6 +74,13 @@ class ConvModule(ShapeMixin, nn.Module):
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias, self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
padding=self.padding, stride=self.stride padding=self.padding, stride=self.stride
) )
if not self.trainable:
for param in self.parameters():
param.requires_grad = False
self.norm = self.norm.eval()
else:
pass
def forward(self, x): def forward(self, x):
tensor = self.norm(x) tensor = self.norm(x)
@ -76,6 +91,27 @@ class ConvModule(ShapeMixin, nn.Module):
return tensor return tensor
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
class SobelFilter(ShapeMixin, nn.Module):
def __init__(self, in_shape):
super(SobelFilter, self).__init__()
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
def forward(self, x):
# Apply Filters
g_x = F.conv2d(x, self.sobel_x)
g_y = F.conv2d(x, self.sobel_y)
# Calculate the Edge
g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
# Calculate the Gradient
g_grad = torch.atan2(g_x, g_y)
return g_x, g_y, g, g_grad
class DeConvModule(ShapeMixin, nn.Module): class DeConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0, def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,