InterSpeech Camera Ready Reporting
@@ -5,6 +5,7 @@ import torch
 import warnings
 
 from torch import nn
+from torch.nn import functional as F
 import sys
 sys.path.append(str(Path(__file__).parent))
 from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
@@ -42,14 +43,21 @@ class LinearModule(ShapeMixin, nn.Module):
 class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
-                 bias=True, norm=False, dropout: Union[int, float] = 0,
+                 bias=True, norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
         assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
         assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        if norm and not trainable:
+            warnings.warn('You set this module to be non-trainable, but the running norm is active.\n'
+                          'We set it to "eval" mode.\n'
+                          'Keep this in mind if you do a fine-tuning or retraining step.'
+                          )
+
         # Module Parameters
         self.in_shape = in_shape
+        self.trainable = trainable
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
 
         # Convolution Parameters
@@ -66,6 +74,13 @@ class ConvModule(ShapeMixin, nn.Module):
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
+        if not self.trainable:
+            # Freeze all parameters and switch the norm layer to eval mode
+            # so that its running statistics are no longer updated.
+            for param in self.parameters():
+                param.requires_grad = False
+            self.norm = self.norm.eval()
+
 
     def forward(self, x):
         tensor = self.norm(x)
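
Note: below is a minimal sketch of how the new trainable flag (together with the norm warning above) is meant to behave. The import path and the surrounding model are placeholders, not part of this commit:

    import torch
    from modules.blocks import ConvModule  # hypothetical import path

    conv = ConvModule(in_shape=(3, 32, 32), conv_filters=16, conv_kernel=3,
                      norm=True, trainable=False)
    # Every parameter of the block is frozen:
    assert not any(p.requires_grad for p in conv.parameters())
    # When the frozen block sits inside a larger model, hand the optimizer
    # only the parameters that still require gradients, e.g.
    # torch.optim.Adam(p for p in full_model.parameters() if p.requires_grad)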
@@ -76,6 +91,27 @@ class ConvModule(ShapeMixin, nn.Module):
         return tensor
 
 
+# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
+
+class SobelFilter(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(SobelFilter, self).__init__()
+        self.in_shape = in_shape  # stored for ShapeMixin, matching ConvModule
+        self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float).view(1, 1, 3, 3)
+        self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float).view(1, 1, 3, 3)
+
+    def forward(self, x):
+        # Apply the filters (the 1x1x3x3 kernels expect single-channel input)
+        g_x = F.conv2d(x, self.sobel_x)
+        g_y = F.conv2d(x, self.sobel_y)
+        # Calculate the squared edge magnitude
+        g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
+        # Calculate the gradient direction
+        g_grad = torch.atan2(g_y, g_x)
+        return g_x, g_y, g, g_grad
+
+
 class DeConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
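
For reference, a minimal sketch of how SobelFilter could be exercised. The input shape here is an assumption; the 1x1x3x3 kernels imply a single-channel batch, and with no padding a 3x3 convolution shrinks height and width by 2:

    import torch
    sobel = SobelFilter(in_shape=(1, 32, 32))
    x = torch.rand(4, 1, 32, 32)   # batch of four single-channel images
    g_x, g_y, g, g_grad = sobel(x)
    print(g_x.shape)               # torch.Size([4, 1, 30, 30])

One design caveat: the kernels are plain tensors rather than registered buffers, so they stay on the CPU when the module is moved with .to(device). Registering them via self.register_buffer('sobel_x', ...) would make them follow the module.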