InterSpeech Camera Ready Reporting
This commit is contained in:
parent
4b089729b2
commit
a4b6c698c3
@ -5,6 +5,7 @@ import torch
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(str(Path(__file__).parent))
|
sys.path.append(str(Path(__file__).parent))
|
||||||
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||||
@ -42,14 +43,21 @@ class LinearModule(ShapeMixin, nn.Module):
|
|||||||
class ConvModule(ShapeMixin, nn.Module):
|
class ConvModule(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
|
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
|
||||||
bias=True, norm=False, dropout: Union[int, float] = 0,
|
bias=True, norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
|
||||||
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
|
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
|
||||||
super(ConvModule, self).__init__()
|
super(ConvModule, self).__init__()
|
||||||
assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
|
assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
|
||||||
assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
|
assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
|
||||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||||
|
if norm and not trainable:
|
||||||
|
warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
|
||||||
|
'We set it to "eval" mode.\n' +
|
||||||
|
'Keep this in mind if you do a finetunning or retraining step.'
|
||||||
|
)
|
||||||
|
|
||||||
# Module Parameters
|
# Module Parameters
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
|
self.trainable = trainable
|
||||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||||
|
|
||||||
# Convolution Parameters
|
# Convolution Parameters
|
||||||
@ -66,6 +74,13 @@ class ConvModule(ShapeMixin, nn.Module):
|
|||||||
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
|
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
|
||||||
padding=self.padding, stride=self.stride
|
padding=self.padding, stride=self.stride
|
||||||
)
|
)
|
||||||
|
if not self.trainable:
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.norm = self.norm.eval()
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = self.norm(x)
|
tensor = self.norm(x)
|
||||||
@ -76,6 +91,27 @@ class ConvModule(ShapeMixin, nn.Module):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
|
class SobelFilter(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_shape):
|
||||||
|
super(SobelFilter, self).__init__()
|
||||||
|
|
||||||
|
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
||||||
|
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Apply Filters
|
||||||
|
g_x = F.conv2d(x, self.sobel_x)
|
||||||
|
g_y = F.conv2d(x, self.sobel_y)
|
||||||
|
# Calculate the Edge
|
||||||
|
g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
|
||||||
|
# Calculate the Gradient
|
||||||
|
g_grad = torch.atan2(g_x, g_y)
|
||||||
|
return g_x, g_y, g, g_grad
|
||||||
|
|
||||||
|
|
||||||
class DeConvModule(ShapeMixin, nn.Module):
|
class DeConvModule(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
|
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user