LinearModule

commit f6c6726509 (parent d2e74ff33a)
@@ -4,25 +4,49 @@ import torch
 import warnings
 
 from torch import nn
-from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin
+from ml_lib.modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
 
-DEVICE = torch.cuda.is_available()
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 #
 # Sub - Modules
 ###################
+class LinearModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, out_features, activation=None, bias=True,
+                 norm=False, dropout: Union[int, float] = 0, **kwargs):
+        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        super(LinearModule, self).__init__()
+
+        self.in_shape = in_shape
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+        self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
+        self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
+        self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
+        self.activation = activation() or F_x(self.linear.out_features)
+
+    def forward(self, x):
+        tensor = self.flat(x)
+        tensor = self.norm(tensor)
+        tensor = self.linear(tensor)
+        tensor = self.activation(tensor)
+        return tensor
+
+
 class ConvModule(ShapeMixin, nn.Module):
 
     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                  bias=True, norm=False, dropout: Union[int, float] = 0,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
+        assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
+        assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         # Module Parameters
         self.in_shape = in_shape
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
-        self.activation = activation()
 
         # Convolution Parameters
         self.padding = conv_padding
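
A minimal usage sketch of the LinearModule added above (not part of the commit). The import path ml_lib.modules.blocks is an assumption, the sketch presumes Flatten exposes the flattened feature count as .shape (as the constructor relies on), and an activation class is passed explicitly because the None default is not callable in `activation() or F_x(...)`.

import torch
from torch import nn
from ml_lib.modules.blocks import LinearModule   # module path is an assumption

# in_shape may be a tuple/list (flattened internally via Flatten) or a plain feature count
layer = LinearModule(in_shape=(3, 8, 8), out_features=16,
                     activation=nn.ReLU, bias=True, norm=True, dropout=0.2)

x = torch.randn(4, 3, 8, 8)   # batch of 4 samples with the declared in_shape
y = layer(x)                  # flat -> norm -> linear -> activation, shape (4, 16)
print(layer.shape)            # ShapeMixin dry run, prints 16
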
@@ -31,16 +55,17 @@ class ConvModule(ShapeMixin, nn.Module):
         self.conv_kernel = conv_kernel
 
         # Modules
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
-        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
+        self.activation = activation() or F_x(None)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
+        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
 
     def forward(self, x):
-        x = self.norm(x)
-        tensor = self.conv(x)
+        tensor = self.norm(x)
+        tensor = self.conv(tensor)
         tensor = self.dropout(tensor)
         tensor = self.pooling(tensor)
         tensor = self.activation(tensor)
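
For comparison, a sketch of the updated ConvModule (again not from the repository): the lambda placeholders were replaced with F_x instances, so the no-op branches are proper nn.Modules, and forward now normalizes the input before convolving. The parameter values and import path below are assumptions, and the sketch presumes the unchanged parts of __init__ still bind conv_filters and conv_stride as before.

import torch
from torch import nn
from ml_lib.modules.blocks import ConvModule     # module path is an assumption

conv = ConvModule(in_shape=(3, 32, 32), conv_filters=16, conv_kernel=3,
                  activation=nn.ELU, pooling_size=2, norm=True, dropout=0.1)

out = conv(torch.randn(8, 3, 32, 32))   # norm -> conv -> dropout -> pooling -> activation
print(conv.shape)                       # ShapeMixin dry run, e.g. (16, 15, 15) here
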
@@ -4,7 +4,6 @@ from pathlib import Path
 import torch
 from torch import nn
 from torch import functional as F
-from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 
@@ -14,25 +13,26 @@ import pytorch_lightning as pl
 from ml_lib.utils.model_io import ModelParameters
 
 
-class F_x(object):
-    def __init__(self):
-        pass
-
-    def __call__(self, x):
-        return x
-
-
 class ShapeMixin:
 
     @property
     def shape(self):
-        try:
-            x = torch.randn(self.in_shape).unsqueeze(0)
-            output = self(x)
-            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
-        except Exception as e:
-            print(e)
-            return -1
+        assert isinstance(self, (LightningBaseModule, nn.Module))
+        x = torch.randn(self.in_shape)
+        # This is needed for BatchNorm shape checking
+        x = torch.stack((x, x))
+        output = self(x)
+        return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
 
 
+class F_x(ShapeMixin, nn.Module):
+    def __init__(self, in_shape):
+        super(F_x, self).__init__()
+        self.in_shape = in_shape
+
+    def forward(self, x):
+        return x
+
+
 # Utility - Modules
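
The reworked ShapeMixin.shape stacks the random sample into a batch of two before the dry-run forward pass. A standalone sketch of why that matters: BatchNorm layers reject a batch of one while in training mode.

import torch
from torch import nn

bn = nn.BatchNorm1d(8)                  # modules are in training mode by default
single = torch.randn(8)

try:
    bn(single.unsqueeze(0))             # batch of 1 -> ValueError while training
except ValueError as err:
    print('batch of 1 fails:', err)

batch = torch.stack((single, single))   # batch of 2, mirroring ShapeMixin.shape
print(bn(batch).shape)                  # torch.Size([2, 8])
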
@@ -128,9 +128,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
     def size(self):
         return self.shape
 
-    def _move_to_model_device(self, x):
-        return x.cuda() if next(self.parameters()).is_cuda else x.cpu()
-
     def save_to_disk(self, model_path):
         Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
         if not (model_path / 'model_class.obj').exists():
@@ -207,7 +204,7 @@ class AutoPadToShape(object):
         x = torch.as_tensor(x)
         if x.shape[1:] == self.shape:
             return x
-        embedding = torch.zeros((x.shape[0], *self.shape), device='cuda' if x.is_cuda else'cpu')
+        embedding = torch.zeros((x.shape[0], *self.shape))
         embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
         return embedding
 
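
A sketch of what AutoPadToShape does after the change, with an assumed target shape: the zero tensor is now allocated without an explicit device argument, and the input is copied into its leading corner.

import torch

target_shape = (3, 32, 32)                 # assumed target shape
x = torch.randn(4, 3, 28, 28)              # batch of smaller inputs

embedding = torch.zeros((x.shape[0], *target_shape))
embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
print(embedding.shape)                     # torch.Size([4, 3, 32, 32])
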
@@ -18,19 +18,12 @@ class ModelParameters(Namespace, Mapping):
 
         paramter_mapping.update(
             dict(
-                activation=self._activations[paramter_mapping['activation']]
+                activation=self._activations[self['activation']]
             )
         )
 
         return paramter_mapping
 
-    @property
-    def test_activation(self):
-        try:
-            return self._activations[self.model.activation]
-        except KeyError:
-            return nn.ReLU
-
     def __getitem__(self, k):
         # k: _KT -> _VT_co
         return self.__dict__[k]
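
A hedged sketch of the lookup pattern this change relies on: the activation name stored on the namespace is resolved through the name-to-class mapping via __getitem__ (self['activation'] reads self.__dict__['activation']). The mapping contents and the stand-in class below are illustrative assumptions, not copied from the repository.

from torch import nn

_activations = dict(relu=nn.ReLU, elu=nn.ELU, leaky_relu=nn.LeakyReLU)   # assumed contents

class ParamsSketch:
    """Stand-in for ModelParameters' dict-style attribute access."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, k):
        return self.__dict__[k]

params = ParamsSketch(activation='elu')
activation_cls = _activations[params['activation']]   # -> nn.ELU
print(activation_cls)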