BandwiseBinaryClassifier is work in progress; TODO: Shape Piping.

This commit is contained in:
Si11ium
2020-05-04 18:45:12 +02:00
parent 6d8fbd7184
commit f285200917
6 changed files with 123 additions and 94 deletions

View File

@ -1,22 +1,15 @@
from typing import Union
import warnings
import torch
from torch import nn
from ml_lib.modules.utils import AutoPad, Interpolate
from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin
#
# Sub - Modules
###################
class ConvModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
class ConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
bias=True, norm=False, dropout: Union[int, float] = 0,
@ -51,13 +44,7 @@ class ConvModule(nn.Module):
return tensor
class DeConvModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
class DeConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=0,
@ -91,13 +78,7 @@ class DeConvModule(nn.Module):
return tensor
class ResidualModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
class ResidualModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
@ -118,13 +99,7 @@ class ResidualModule(nn.Module):
return tensor
class RecurrentModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
class RecurrentModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, bias=True, dropout=0):
super(RecurrentModule, self).__init__()