LinearModule
@@ -4,25 +4,49 @@ import torch
 import warnings
 
 from torch import nn
-from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin
-
-DEVICE = torch.cuda.is_available()
+from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
+#
+# Sub - Modules
+###################
+class LinearModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, out_features, activation=None, bias=True,
+                 norm=False, dropout: Union[int, float] = 0, **kwargs):
+        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        super(LinearModule, self).__init__()
+
+        self.in_shape = in_shape
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
+        self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
+        self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
+        self.activation = activation() or F_x(self.linear.out_features)
+
+    def forward(self, x):
+        tensor = self.flat(x)
+        tensor = self.norm(tensor)
+        tensor = self.linear(tensor)
+        tensor = self.activation(tensor)
+        return tensor
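For orientation, a minimal usage sketch of the new LinearModule follows (an editor's illustration, not part of the committed diff). It assumes that Flatten(in_shape) exposes the flattened feature count as .shape and that F_x(...) acts as an identity module; both come from ml_lib.modules.utils, which this commit does not show. Note that with the default activation=None the expression activation() or F_x(...) would raise a TypeError, so the sketch passes an activation class explicitly. The import path is hypothetical and would need to match the repository layout.

# Editor's sketch (assumptions noted above), not part of the committed diff.
import torch
from ml_lib.modules.blocks import LinearModule  # hypothetical import path

module = LinearModule(in_shape=(3, 32, 32), out_features=10,
                      activation=torch.nn.ReLU,  # explicit, since activation=None would fail in activation()
                      norm=True, dropout=0.2)
batch = torch.randn(8, 3, 32, 32)   # flattened internally to (8, 3 * 32 * 32)
logits = module(batch)
print(logits.shape)                 # expected: torch.Size([8, 10])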
+
 
 class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                  bias=True, norm=False, dropout: Union[int, float] = 0,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
         assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
         assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         # Module Parameters
         self.in_shape = in_shape
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.activation = activation()
 
         # Convolution Parameters
         self.padding = conv_padding
@@ -31,16 +55,17 @@ class ConvModule(ShapeMixin, nn.Module):
         self.conv_kernel = conv_kernel
 
         # Modules
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
-        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
+        self.activation = activation() or F_x(None)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
+        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
 
     def forward(self, x):
-        x = self.norm(x)
-        tensor = self.conv(x)
+        tensor = self.norm(x)
+        tensor = self.conv(tensor)
         tensor = self.dropout(tensor)
         tensor = self.pooling(tensor)
         tensor = self.activation(tensor)
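Likewise, a hedged sketch of how the revised ConvModule could be exercised, assuming the constructor lines elided between the two hunks wire conv_stride and conv_filters through to self.stride and self.conv_filters as the surrounding code suggests; the new forward order is norm -> conv -> dropout -> pooling -> activation.

# Editor's sketch (same assumptions and hypothetical import path as above), not part of the committed diff.
import torch
from torch import nn
from ml_lib.modules.blocks import ConvModule  # hypothetical import path

conv = ConvModule(in_shape=(3, 32, 32), conv_filters=16, conv_kernel=3,
                  activation=nn.ELU, pooling_size=2, norm=True, dropout=0.1,
                  conv_stride=1, conv_padding=1)
batch = torch.randn(8, 3, 32, 32)
features = conv(batch)      # norm -> conv -> dropout -> pooling -> activation
print(features.shape)       # expected: torch.Size([8, 16, 16, 16]) for kernel 3, padding 1, pooling 2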