CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

View File

@ -22,7 +22,7 @@ class ConvModule(nn.Module):
return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
dropout: Union[int, float] = 0,
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
@ -39,9 +39,9 @@ class ConvModule(nn.Module):
self.dropout = nn.Dropout2d(dropout) if dropout else False
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
def forward(self, x):
x = self.norm(x) if self.norm else x
@ -91,8 +91,30 @@ class DeConvModule(nn.Module):
tensor = self.activation(tensor) if self.activation else tensor
return tensor
def size(self):
return self.shape
class ResidualModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, **module_paramters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
self.residual_block = [module_class(**module_paramters) for x in range(n)]
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
for module in self.residual_block:
tensor = module(x)
# noinspection PyUnboundLocalVariable
tensor = tensor + x
return tensor
class RecurrentModule(nn.Module):

View File

@ -6,7 +6,6 @@ from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader
from dataset.dataset import TrajDataset, TrajPairDataset
from lib.objects.map import MapStorage
import pytorch_lightning as pl
@ -17,8 +16,20 @@ import pytorch_lightning as pl
class Flatten(nn.Module):
def __init__(self, to=(-1, )):
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, in_shape, to=(-1, )):
super(Flatten, self).__init__()
self.in_shape = in_shape
self.to = to
def forward(self, x):