BandwiseBinaryClassifier is work in progress; TODO: Shape Piping.
@@ -1,22 +1,15 @@
 from typing import Union
 import warnings
 
 import torch
 from torch import nn
-from ml_lib.modules.utils import AutoPad, Interpolate
+from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin
 
 
 #
 # Sub - Modules
 ###################
 
-class ConvModule(nn.Module):
-
-    @property
-    def shape(self):
-        x = torch.randn(self.in_shape).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
+class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                  bias=True, norm=False, dropout: Union[int, float] = 0,
@@ -51,13 +44,7 @@ class ConvModule(nn.Module):
         return tensor
 
 
-class DeConvModule(nn.Module):
-
-    @property
-    def shape(self):
-        x = torch.randn(self.in_shape).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
+class DeConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                  dropout: Union[int, float] = 0, autopad=0,
@@ -91,13 +78,7 @@ class DeConvModule(nn.Module):
         return tensor
 
 
-class ResidualModule(nn.Module):
-
-    @property
-    def shape(self):
-        x = torch.randn(self.in_shape).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
+class ResidualModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
         assert n >= 1
@@ -118,13 +99,7 @@ class ResidualModule(nn.Module):
         return tensor
 
 
-class RecurrentModule(nn.Module):
-
-    @property
-    def shape(self):
-        x = torch.randn(self.in_shape).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
+class RecurrentModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, bias=True, dropout=0):
         super(RecurrentModule, self).__init__()
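
Note on the refactor: each of the four modules above previously carried an identical `shape` property, which pushes one random dummy tensor of size `in_shape` through the module and reports the resulting output shape with the batch dimension stripped. This commit moves that duplicated logic behind `ShapeMixin` from `ml_lib.modules.utils`. The mixin itself is not part of this diff; the following is only a sketch of what it could look like, reconstructed from the removed properties, and the actual implementation may differ.

import torch


class ShapeMixin:
    # Sketch only, not the code from ml_lib.modules.utils.
    # Assumes the inheriting nn.Module sets `self.in_shape` in its __init__.

    @property
    def shape(self):
        # Run a single random sample of size `in_shape` through the module
        # and return the output shape without the batch dimension.
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

With such a property on every module, the "Shape Piping" TODO in the commit message presumably refers to chaining modules by feeding one module's `.shape` into the next module's `in_shape` (e.g. `DeConvModule(in_shape=conv.shape, ...)`); that usage is an assumption and is not shown in this diff.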