ResidualModule and New Parameters, Speed Manipulation
This commit is contained in:
@ -15,7 +15,7 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
###################
|
||||
class LinearModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, out_features, activation=None, bias=True,
|
||||
def __init__(self, in_shape, out_features, bias=True, activation=None,
|
||||
norm=False, dropout: Union[int, float] = 0, **kwargs):
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
super(LinearModule, self).__init__()
|
||||
@ -25,10 +25,11 @@ class LinearModule(ShapeMixin, nn.Module):
|
||||
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
|
||||
self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
|
||||
self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
|
||||
self.activation = activation() or F_x(self.linear.out_features)
|
||||
self.activation = activation() if activation else F_x(self.linear.out_features)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.flat(x)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.norm(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.activation(tensor)
|
||||
@ -108,12 +109,16 @@ class DeConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
class ResidualModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
|
||||
def __init__(self, in_shape, module_class, n, **module_parameters):
|
||||
assert n >= 1
|
||||
super(ResidualModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
module_parameters.update(in_shape=in_shape)
|
||||
self.activation = activation() if activation else lambda x: x
|
||||
self.activation = module_parameters.get('activation', None)
|
||||
if self.activation is not None:
|
||||
self.activation = self.activation()
|
||||
else:
|
||||
self.activation = F_x(self.in_shape)
|
||||
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
|
||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||
|
||||
|
Reference in New Issue
Block a user