New Model, Many Changes
This commit is contained in:
parent
14ed4e0117
commit
cfeea05673
@ -155,9 +155,10 @@ class PreNorm(nn.Module):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class SlidingWindow(nn.Module):
|
||||
def __init__(self, kernel, stride=1, padding=0, keepdim=False):
|
||||
class SlidingWindow(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
|
||||
super(SlidingWindow, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
@ -263,24 +264,25 @@ class FlipTensor(nn.Module):
|
||||
return inverted_tensor
|
||||
|
||||
|
||||
class AutoPadToShape(object):
|
||||
def __init__(self, shape):
|
||||
self.shape = shape
|
||||
class AutoPadToShape(nn.Module):
|
||||
def __init__(self, target_shape):
|
||||
super(AutoPadToShape, self).__init__()
|
||||
self.target_shape = target_shape
|
||||
|
||||
def __call__(self, x):
|
||||
def forward(self, x):
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.as_tensor(x)
|
||||
if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
|
||||
if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
|
||||
return x
|
||||
|
||||
idx = [0] * (len(self.shape) * 2)
|
||||
for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
|
||||
idx[j] = self.shape[i] - x.shape[i]
|
||||
idx = [0] * (len(self.target_shape) * 2)
|
||||
for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
|
||||
idx[j] = self.target_shape[i] - x.shape[i]
|
||||
x = torch.nn.functional.pad(x, idx)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f'AutoPadTransform({self.shape})'
|
||||
return f'AutoPadTransform({self.target_shape})'
|
||||
|
||||
|
||||
class Splitter(nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user