New Model, Many Changes
This commit is contained in:
parent
14ed4e0117
commit
cfeea05673
@ -155,9 +155,10 @@ class PreNorm(nn.Module):
|
|||||||
return self.fn(self.norm(x), **kwargs)
|
return self.fn(self.norm(x), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SlidingWindow(nn.Module):
|
class SlidingWindow(ShapeMixin, nn.Module):
|
||||||
def __init__(self, kernel, stride=1, padding=0, keepdim=False):
|
def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
|
||||||
super(SlidingWindow, self).__init__()
|
super(SlidingWindow, self).__init__()
|
||||||
|
self.in_shape = in_shape
|
||||||
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
|
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
@ -263,24 +264,25 @@ class FlipTensor(nn.Module):
|
|||||||
return inverted_tensor
|
return inverted_tensor
|
||||||
|
|
||||||
|
|
||||||
class AutoPadToShape(object):
|
class AutoPadToShape(nn.Module):
|
||||||
def __init__(self, shape):
|
def __init__(self, target_shape):
|
||||||
self.shape = shape
|
super(AutoPadToShape, self).__init__()
|
||||||
|
self.target_shape = target_shape
|
||||||
|
|
||||||
def __call__(self, x):
|
def forward(self, x):
|
||||||
if not torch.is_tensor(x):
|
if not torch.is_tensor(x):
|
||||||
x = torch.as_tensor(x)
|
x = torch.as_tensor(x)
|
||||||
if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
|
if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
idx = [0] * (len(self.shape) * 2)
|
idx = [0] * (len(self.target_shape) * 2)
|
||||||
for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
|
for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
|
||||||
idx[j] = self.shape[i] - x.shape[i]
|
idx[j] = self.target_shape[i] - x.shape[i]
|
||||||
x = torch.nn.functional.pad(x, idx)
|
x = torch.nn.functional.pad(x, idx)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'AutoPadTransform({self.shape})'
|
return f'AutoPadTransform({self.target_shape})'
|
||||||
|
|
||||||
|
|
||||||
class Splitter(nn.Module):
|
class Splitter(nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user