CNN Classifier
`ConvModule.__init__` gains a `conv_class` argument (default `nn.Conv2d`) so the convolution type can be injected:

```diff
@@ -22,7 +22,7 @@ class ConvModule(nn.Module):
         return output.shape[1:]

     def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
-                 dropout: Union[int, float] = 0,
+                 dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
                  conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
         super(ConvModule, self).__init__()
```
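For context, `return output.shape[1:]` is the tail of `ConvModule`'s `shape` property, which infers the output dimensions by pushing a dummy one-sample batch through the module itself. A minimal self-contained sketch of that pattern (`ShapeProbe` is illustrative, not a class from this repo):

```python
import torch
from torch import nn


class ShapeProbe(nn.Module):
    """Illustrative: infer a module's output shape via a dummy forward pass."""

    def __init__(self, in_shape=(3, 32, 32)):
        super().__init__()
        self.in_shape = in_shape
        self.conv = nn.Conv2d(in_shape[0], 8, kernel_size=5)

    @property
    def shape(self):
        # One fake sample in, batch dimension stripped off the result.
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    def forward(self, x):
        return self.conv(x)


print(ShapeProbe().shape)  # torch.Size([8, 28, 28])
```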
The constructor body then builds the convolution through the injected class instead of hard-coding `nn.Conv2d`:

```diff
@@ -39,9 +39,9 @@ class ConvModule(nn.Module):
         self.dropout = nn.Dropout2d(dropout) if dropout else False
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
         self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
-        self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
-                              padding=self.padding, stride=self.stride
-                              )
+        self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
+                               padding=self.padding, stride=self.stride
+                               )

     def forward(self, x):
         x = self.norm(x) if self.norm else x
```
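What the injected `conv_class` buys: any callable with `nn.Conv2d`'s constructor signature can be swapped in at the same call site, e.g. `nn.ConvTranspose2d`. A hedged sketch (the `build_conv` helper is hypothetical, shown only to isolate the call):

```python
import torch
from torch import nn


def build_conv(conv_class=nn.Conv2d, in_channels=3, conv_filters=64,
               conv_kernel=5, use_bias=True, stride=1, padding=0):
    # Mirrors the call in ConvModule.__init__: the conv class is a parameter.
    return conv_class(in_channels, conv_filters, conv_kernel,
                      bias=use_bias, padding=padding, stride=stride)


standard = build_conv()                                 # plain nn.Conv2d
transposed = build_conv(conv_class=nn.ConvTranspose2d)  # same call site
print(standard(torch.randn(1, 3, 32, 32)).shape)        # torch.Size([1, 64, 28, 28])
print(transposed(torch.randn(1, 3, 32, 32)).shape)      # torch.Size([1, 64, 36, 36])
```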
The commit also adds a `ResidualModule` that stacks `n` instances of a `module_class` and adds a skip connection:

```diff
@@ -91,8 +91,30 @@ class DeConvModule(nn.Module):
         tensor = self.activation(tensor) if self.activation else tensor
         return tensor

     def size(self):
         return self.shape

+
+class ResidualModule(nn.Module):
+
+    @property
+    def shape(self):
+        x = torch.randn(self.in_shape).unsqueeze(0)
+        output = self(x)
+        return output.shape[1:]
+
+    def __init__(self, in_shape, module_class, n, **module_paramters):
+        assert n >= 1
+        super(ResidualModule, self).__init__()
+        self.in_shape = in_shape
+        module_paramters.update(in_shape=in_shape)
+        self.residual_block = [module_class(**module_paramters) for x in range(n)]
+        assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
+
+    def forward(self, x):
+        for module in self.residual_block:
+            tensor = module(x)
+
+        # noinspection PyUnboundLocalVariable
+        tensor = tensor + x
+        return tensor
+
+
 class RecurrentModule(nn.Module):
```
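Two review-style observations on the added class: a plain Python list will not register the stacked sub-modules' parameters with the parent `nn.Module` (the idiomatic container is `nn.ModuleList`), and `forward` applies every stacked module to the original `x` rather than chaining outputs, so only the last module's result reaches the skip connection. A hedged corrected sketch under those observations (`OneByOne` is a hypothetical shape-preserving `module_class` stand-in):

```python
import torch
from torch import nn


class ResidualModuleFixed(nn.Module):
    """Sketch: ResidualModule with registered sub-modules and a chained forward."""

    def __init__(self, in_shape, module_class, n, **module_parameters):
        assert n >= 1
        super().__init__()
        self.in_shape = in_shape
        module_parameters.update(in_shape=in_shape)
        # nn.ModuleList registers each stacked module's parameters.
        self.residual_block = nn.ModuleList(
            module_class(**module_parameters) for _ in range(n)
        )

    def forward(self, x):
        tensor = x
        for module in self.residual_block:
            tensor = module(tensor)  # chain outputs instead of re-using x
        return tensor + x            # skip connection


class OneByOne(nn.Module):
    """Hypothetical module_class: a shape-preserving 1x1 convolution."""

    def __init__(self, in_shape):
        super().__init__()
        self.conv = nn.Conv2d(in_shape[0], in_shape[0], kernel_size=1)

    def forward(self, x):
        return self.conv(x)


block = ResidualModuleFixed(in_shape=(8, 16, 16), module_class=OneByOne, n=2)
print(block(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```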
A second file's import block (one line removed, per the hunk header):

```diff
@@ -6,7 +6,6 @@ from torch import nn
 from torch import functional as F
 from torch.utils.data import DataLoader

 from dataset.dataset import TrajDataset, TrajPairDataset
 from lib.objects.map import MapStorage

 import pytorch_lightning as pl
```
In the same file, `Flatten` now takes its `in_shape` and exposes the same dummy-forward `shape` property:

```diff
@@ -17,8 +16,20 @@ import pytorch_lightning as pl


 class Flatten(nn.Module):
-    def __init__(self, to=(-1, )):
+
+    @property
+    def shape(self):
+        try:
+            x = torch.randn(self.in_shape).unsqueeze(0)
+            output = self(x)
+            return output.shape[1:]
+        except Exception as e:
+            print(e)
+            return -1
+
+    def __init__(self, in_shape, to=(-1, )):
         super(Flatten, self).__init__()
+        self.in_shape = in_shape
         self.to = to

     def forward(self, x):
```
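The hunk cuts off before the body of `forward`, so the following is a hedged reconstruction of how the reworked class is presumably used; the reshape in `forward` is an assumption, not something shown in the diff:

```python
import torch
from torch import nn


class FlattenSketch(nn.Module):
    """Hypothetical reconstruction -- the real forward body is not in the hunk."""

    def __init__(self, in_shape, to=(-1, )):
        super().__init__()
        self.in_shape = in_shape
        self.to = to

    @property
    def shape(self):
        # Same dummy-forward probe as in the diff above.
        x = torch.randn(self.in_shape).unsqueeze(0)
        return self(x).shape[1:]

    def forward(self, x):
        # Assumed: keep the batch dimension, reshape the rest to self.to.
        return x.view(x.size(0), *self.to)


flat = FlattenSketch(in_shape=(8, 16, 16))
print(flat.shape)                             # torch.Size([2048])
print(flat(torch.randn(4, 8, 16, 16)).shape)  # torch.Size([4, 2048])
```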