Residual Model
This commit is contained in:
@ -109,11 +109,16 @@ class DeConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
class ResidualModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, module_class, n, **module_parameters):
|
||||
def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
|
||||
assert n >= 1
|
||||
super(ResidualModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
module_parameters.update(in_shape=in_shape)
|
||||
if norm:
|
||||
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
|
||||
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
|
||||
else:
|
||||
self.norm = F_x(self.in_shape)
|
||||
self.activation = module_parameters.get('activation', None)
|
||||
if self.activation is not None:
|
||||
self.activation = self.activation()
|
||||
|
Reference in New Issue
Block a user