ResidualModule and New Parameters, Speed Manipulation

Si11ium
2020-05-12 12:37:25 +02:00
parent f6c6726509
commit dfe2db342f
4 changed files with 132 additions and 63 deletions


@@ -15,7 +15,7 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 ###################
 class LinearModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, out_features, activation=None, bias=True,
+    def __init__(self, in_shape, out_features, bias=True, activation=None,
                  norm=False, dropout: Union[int, float] = 0, **kwargs):
         warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         super(LinearModule, self).__init__()
@@ -25,10 +25,11 @@ class LinearModule(ShapeMixin, nn.Module):
         self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
         self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
         self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
-        self.activation = activation() or F_x(self.linear.out_features)
+        self.activation = activation() if activation else F_x(self.linear.out_features)
+
     def forward(self, x):
         tensor = self.flat(x)
         tensor = self.dropout(tensor)
         tensor = self.norm(tensor)
         tensor = self.linear(tensor)
         tensor = self.activation(tensor)
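The activation line above fixes a real bug: `activation` defaults to None, so the old expression `activation() or F_x(...)` called None and raised a TypeError before the `or` fallback could ever apply. The new conditional only invokes the factory when one was passed. A self-contained sketch of the failure and the fix (the toy F_x below merely mimics the repo's identity module; its real definition differs):

    import torch.nn as nn

    class F_x(nn.Module):
        """Identity stand-in for the repo's F_x; toy version, signature assumed."""
        def __init__(self, in_shape):
            super().__init__()

        def forward(self, x):
            return x

    activation = None

    # old line: the factory is called unconditionally, so None() blows up
    try:
        layer = activation() or F_x(10)
    except TypeError as err:
        print(err)  # 'NoneType' object is not callable

    # new line: the factory is only called when one was actually provided
    layer = activation() if activation else F_x(10)
    print(layer)  # F_x()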
@@ -108,12 +109,16 @@ class DeConvModule(ShapeMixin, nn.Module):
 class ResidualModule(ShapeMixin, nn.Module):
-    def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
+    def __init__(self, in_shape, module_class, n, **module_parameters):
         assert n >= 1
         super(ResidualModule, self).__init__()
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
-        self.activation = activation() if activation else lambda x: x
+        self.activation = module_parameters.get('activation', None)
+        if self.activation is not None:
+            self.activation = self.activation()
+        else:
+            self.activation = F_x(self.in_shape)
         self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
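ResidualModule now reads `activation` out of the shared `module_parameters` dict instead of a dedicated keyword, so the same kwargs can be forwarded unchanged to every wrapped `module_class` instance, and the assert guarantees the block preserves its input shape, which a residual addition requires. A self-contained sketch of the pattern (toy classes; the repo's real forward, which the diff does not show, is assumed here to add the block output back onto the input):

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        """Toy stand-in for module_class: a shape-preserving linear layer."""
        def __init__(self, in_shape, activation=None, **kwargs):
            super().__init__()
            self.linear = nn.Linear(in_shape, in_shape)
            self.activation = activation() if activation else nn.Identity()

        def forward(self, x):
            return self.activation(self.linear(x))

    class MiniResidual(nn.Module):
        """Sketch of the ResidualModule pattern; not the repo's class."""
        def __init__(self, in_shape, module_class, n, **module_parameters):
            assert n >= 1
            super().__init__()
            module_parameters.update(in_shape=in_shape)
            # activation travels inside the shared kwargs, as in the diff
            activation = module_parameters.get('activation', None)
            self.activation = activation() if activation else nn.Identity()
            self.residual_block = nn.ModuleList(
                [module_class(**module_parameters) for _ in range(n)]
            )

        def forward(self, x):
            tensor = x
            for module in self.residual_block:
                tensor = module(tensor)
            # residual connection: block output plus the original input
            return self.activation(tensor + x)

    block = MiniResidual(in_shape=32, module_class=Block, n=2, activation=nn.ReLU)
    print(block(torch.randn(4, 32)).shape)  # torch.Size([4, 32])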


@@ -18,12 +18,14 @@ class ShapeMixin:

     @property
     def shape(self):
         assert isinstance(self, (LightningBaseModule, nn.Module))
-        x = torch.randn(self.in_shape)
-        # This is needed for BatchNorm shape checking
-        x = torch.stack((x, x))
-        output = self(x)
-        return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
+        if self.in_shape is not None:
+            x = torch.randn(self.in_shape)
+            # This is needed for BatchNorm shape checking
+            x = torch.stack((x, x))
+            output = self(x)
+            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
+        else:
+            return -1


 class F_x(ShapeMixin, nn.Module):
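The guard above lets `shape` degrade gracefully when a module has no `in_shape` to probe, returning -1 as a sentinel instead of crashing on `torch.randn(None)`. The probe itself is a dry forward pass: two copies of a random input are stacked into a batch so BatchNorm layers, which reject single-sample batches in training mode, can run. A self-contained sketch of the same trick (LightningBaseModule is omitted and the assert simplified to plain nn.Module):

    import torch
    import torch.nn as nn

    class ShapeProbe:
        """Sketch of the ShapeMixin logic: infer the out-shape via a dry forward pass."""
        @property
        def shape(self):
            assert isinstance(self, nn.Module)
            if self.in_shape is not None:
                x = torch.randn(self.in_shape)
                # batch of two, because BatchNorm rejects batches of size 1 in training
                x = torch.stack((x, x))
                output = self(x)
                return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
            else:
                return -1  # sentinel: nothing to probe without an in_shape

    class Probed(ShapeProbe, nn.Module):
        def __init__(self, in_shape, out_features):
            super().__init__()
            self.in_shape = in_shape
            self.linear = nn.Linear(in_shape, out_features)
            self.norm = nn.BatchNorm1d(out_features)

        def forward(self, x):
            return self.norm(self.linear(x))

    print(Probed(in_shape=8, out_features=3).shape)  # 3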