From cfeea05673da1147a3c54bf0df378d32115a92d4 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Sun, 22 Nov 2020 16:23:59 +0100
Subject: [PATCH] New Model, Many Changes

---
 modules/util.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/modules/util.py b/modules/util.py
index 3394aae..7554941 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -155,9 +155,10 @@ class PreNorm(nn.Module):
         return self.fn(self.norm(x), **kwargs)
 
 
-class SlidingWindow(nn.Module):
-    def __init__(self, kernel, stride=1, padding=0, keepdim=False):
+class SlidingWindow(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
         super(SlidingWindow, self).__init__()
+        self.in_shape = in_shape
         self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
         self.padding = padding
         self.stride = stride
@@ -263,24 +264,25 @@ class FlipTensor(nn.Module):
         return inverted_tensor
 
 
-class AutoPadToShape(object):
-    def __init__(self, shape):
-        self.shape = shape
+class AutoPadToShape(nn.Module):
+    def __init__(self, target_shape):
+        super(AutoPadToShape, self).__init__()
+        self.target_shape = target_shape
 
-    def __call__(self, x):
+    def forward(self, x):
         if not torch.is_tensor(x):
             x = torch.as_tensor(x)
-        if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
+        if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
             return x
 
-        idx = [0] * (len(self.shape) * 2)
-        for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
-            idx[j] = self.shape[i] - x.shape[i]
+        idx = [0] * (len(self.target_shape) * 2)
+        for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
+            idx[j] = self.target_shape[i] - x.shape[i]
         x = torch.nn.functional.pad(x, idx)
         return x
 
     def __repr__(self):
-        return f'AutoPadTransform({self.shape})'
+        return f'AutoPadTransform({self.target_shape})'
 
 
 class Splitter(nn.Module):
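
A minimal usage sketch of the converted AutoPadToShape: now that it is an
nn.Module with a forward() instead of a plain callable with __call__, it can
be registered inside nn.Sequential and shows up in module repr and state
handling. The import path follows this diff; the tensor shapes are
illustrative assumptions, not values from the commit:

    import torch
    from modules.util import AutoPadToShape

    # Pad a (3, 28, 28) tensor up to a (3, 32, 32) target shape.
    pad = AutoPadToShape(target_shape=(3, 32, 32))
    x = torch.randn(3, 28, 28)
    y = pad(x)  # nn.Module.__call__ dispatches to forward()
    assert y.shape == torch.Size([3, 32, 32])

    # Being an nn.Module, the transform composes with other layers:
    model = torch.nn.Sequential(pad, torch.nn.Flatten())

Note that forward() writes each size difference into the even slots of the
pad list, so the padding is applied on the leading (left) side of every
trailing dimension that falls short of target_shape.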