New Model, Many Changes

This commit is contained in:
Si11ium 2020-11-22 16:23:59 +01:00
parent 14ed4e0117
commit cfeea05673

View File

@ -155,9 +155,10 @@ class PreNorm(nn.Module):
return self.fn(self.norm(x), **kwargs) return self.fn(self.norm(x), **kwargs)
class SlidingWindow(nn.Module): class SlidingWindow(ShapeMixin, nn.Module):
def __init__(self, kernel, stride=1, padding=0, keepdim=False): def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
super(SlidingWindow, self).__init__() super(SlidingWindow, self).__init__()
self.in_shape = in_shape
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel) self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
self.padding = padding self.padding = padding
self.stride = stride self.stride = stride
@ -263,24 +264,25 @@ class FlipTensor(nn.Module):
return inverted_tensor return inverted_tensor
class AutoPadToShape(object): class AutoPadToShape(nn.Module):
def __init__(self, shape): def __init__(self, target_shape):
self.shape = shape super(AutoPadToShape, self).__init__()
self.target_shape = target_shape
def __call__(self, x): def forward(self, x):
if not torch.is_tensor(x): if not torch.is_tensor(x):
x = torch.as_tensor(x) x = torch.as_tensor(x)
if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape: if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
return x return x
idx = [0] * (len(self.shape) * 2) idx = [0] * (len(self.target_shape) * 2)
for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)): for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
idx[j] = self.shape[i] - x.shape[i] idx[j] = self.target_shape[i] - x.shape[i]
x = torch.nn.functional.pad(x, idx) x = torch.nn.functional.pad(x, idx)
return x return x
def __repr__(self): def __repr__(self):
return f'AutoPadTransform({self.shape})' return f'AutoPadTransform({self.target_shape})'
class Splitter(nn.Module): class Splitter(nn.Module):