Model Loading by string. Within Debugging

This commit is contained in:
Si11ium
2020-08-15 12:42:57 +02:00
parent a4b6c698c3
commit 6bc9447ce1
5 changed files with 108 additions and 58 deletions

View File

@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
return tensor
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
class PreInitializedConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, weight_matrix):
super(PreInitializedConvModule, self).__init__()
self.in_shape = in_shape
raise NotImplementedError
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
# override the weights then.
def forward(self, x):
return x
class SobelFilter(ShapeMixin, nn.Module):
def __init__(self, in_shape):
super(SobelFilter, self).__init__()
self.in_shape = in_shape
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)