Debugging

Author: Si11ium
Date: 2020-02-28 19:11:53 +01:00
Parent: 7b3f781d19
Commit: 44f6589259
18 changed files with 134 additions and 78 deletions


@@ -36,9 +36,9 @@ class ConvModule(nn.Module):
         self.stride = conv_stride
         # Modules
-        self.dropout = nn.Dropout2d(dropout) if dropout else False
-        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
+        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
         self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
                                padding=self.padding, stride=self.stride
                                )
@@ -47,8 +47,8 @@ class ConvModule(nn.Module):
         x = self.norm(x) if self.norm else x
         tensor = self.conv(x)
-        tensor = self.dropout(tensor) if self.dropout else tensor
-        tensor = self.pooling(tensor) if self.pooling else tensor
+        tensor = self.dropout(tensor)
+        tensor = self.pooling(tensor)
         tensor = self.activation(tensor)
         return tensor
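
The pattern in the two hunks above: optional stages no longer fall back to False but to an identity callable, so forward() can call every stage unconditionally. Below is a minimal sketch of the same idea, not taken from the repository (ToyConvBlock and its arguments are made up for illustration); torch.nn.Identity is the built-in equivalent of lambda x: x:

import torch
import torch.nn as nn

class ToyConvBlock(nn.Module):
    """Illustrative only: optional stages default to identity so forward() never branches."""
    def __init__(self, in_channels, out_channels, dropout=0.0, pooling_size=None):
        super().__init__()
        # nn.Identity() behaves like `lambda x: x` but is a proper registered submodule.
        self.dropout = nn.Dropout2d(dropout) if dropout else nn.Identity()
        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else nn.Identity()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        tensor = self.conv(x)
        tensor = self.dropout(tensor)   # no `if self.dropout else tensor` branch needed
        tensor = self.pooling(tensor)
        return tensor

# Quick check: both configurations run through the same unconditional forward().
print(ToyConvBlock(3, 8)(torch.randn(1, 3, 16, 16)).shape)                               # [1, 8, 16, 16]
print(ToyConvBlock(3, 8, dropout=0.2, pooling_size=2)(torch.randn(1, 3, 16, 16)).shape)  # [1, 8, 8, 8]

Compared with a bare lambda, nn.Identity keeps the fallback visible in the module's repr() and picklable, while behaving identically in the forward pass.
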
@@ -72,23 +72,23 @@ class DeConvModule(nn.Module):
         self.in_shape = in_shape
         self.conv_filters = conv_filters
-        self.autopad = AutoPad() if autopad else False
-        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False
-        self.dropout = nn.Dropout2d(dropout) if dropout else False
+        self.autopad = AutoPad() if autopad else lambda x: x
+        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
+        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
                                           padding=self.padding, stride=self.stride)
-        self.activation = activation() if activation else None
+        self.activation = activation() if activation else lambda x: x

     def forward(self, x):
-        x = self.norm(x) if self.norm else x
-        x = self.dropout(x) if self.dropout else x
-        x = self.autopad(x) if self.autopad else x
-        x = self.interpolation(x) if self.interpolation else x
+        x = self.norm(x)
+        x = self.dropout(x)
+        x = self.autopad(x)
+        x = self.interpolation(x)
         tensor = self.de_conv(x)
-        tensor = self.activation(tensor) if self.activation else tensor
+        tensor = self.activation(tensor)
         return tensor
@@ -100,12 +100,13 @@ class ResidualModule(nn.Module):
         output = self(x)
         return output.shape[1:]

-    def __init__(self, in_shape, module_class, n, **module_paramters):
+    def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
         assert n >= 1
         super(ResidualModule, self).__init__()
         self.in_shape = in_shape
         module_paramters.update(in_shape=in_shape)
-        self.residual_block = [module_class(**module_paramters) for x in range(n)]
+        self.activation = activation() if activation else lambda x: x
+        self.residual_block = [module_class(**module_paramters) for _ in range(n)]
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

     def forward(self, x):
@@ -114,6 +115,7 @@ class ResidualModule(nn.Module):
         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
+        tensor = self.activation(tensor)
         return tensor
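
For the ResidualModule change, here is a self-contained sketch of the stack-then-skip-then-activate pattern (ToyResidualModule and the inner conv blocks are illustrative, not the repository's code). The sketch stores the blocks in nn.ModuleList rather than a plain Python list, so they are registered as submodules:

import torch
import torch.nn as nn

class ToyResidualModule(nn.Module):
    """Illustrative only: n stacked blocks, a skip connection, then an optional activation."""
    def __init__(self, channels, n, activation=nn.ReLU):
        super().__init__()
        assert n >= 1
        # nn.ModuleList registers each block, so its parameters appear in
        # .parameters() and move along with .to(device).
        self.residual_block = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(n)]
        )
        self.activation = activation() if activation else nn.Identity()

    def forward(self, x):
        tensor = x
        for block in self.residual_block:
            tensor = block(tensor)
        tensor = tensor + x                # skip connection
        return self.activation(tensor)     # activation applied after the residual sum

print(ToyResidualModule(8, n=2)(torch.randn(1, 8, 4, 4)).shape)  # torch.Size([1, 8, 4, 4])

With a plain list, the blocks would still be callable but their parameters would not be picked up by the optimizer via .parameters(), which is why the sketch opts for nn.ModuleList.
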


@@ -123,6 +123,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
                           batch_size=self.hparams.data_param.batchsize,
                           num_workers=self.hparams.data_param.worker)

+    @property
+    def data_len(self):
+        return len(self.dataset.train_dataset)

     def configure_optimizers(self):
         raise NotImplementedError