Model Loading by string. Within Debugging
This commit is contained in:
@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
|
||||
return tensor
|
||||
|
||||
|
||||
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||
class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, weight_matrix):
|
||||
super(PreInitializedConvModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
raise NotImplementedError
|
||||
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
|
||||
# override the weights then.
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SobelFilter(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape):
|
||||
super(SobelFilter, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
||||
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
||||
|
Reference in New Issue
Block a user