New Model, Many Changes
This commit is contained in:
		| @@ -155,9 +155,10 @@ class PreNorm(nn.Module): | ||||
|         return self.fn(self.norm(x), **kwargs) | ||||
|  | ||||
|  | ||||
| class SlidingWindow(nn.Module): | ||||
|     def __init__(self, kernel, stride=1, padding=0, keepdim=False): | ||||
| class SlidingWindow(ShapeMixin, nn.Module): | ||||
|     def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False): | ||||
|         super(SlidingWindow, self).__init__() | ||||
|         self.in_shape = in_shape | ||||
|         self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel) | ||||
|         self.padding = padding | ||||
|         self.stride = stride | ||||
| @@ -263,24 +264,25 @@ class FlipTensor(nn.Module): | ||||
|         return inverted_tensor | ||||
|  | ||||
|  | ||||
| class AutoPadToShape(object): | ||||
|     def __init__(self, shape): | ||||
|         self.shape = shape | ||||
| class AutoPadToShape(nn.Module): | ||||
|     def __init__(self, target_shape): | ||||
|         super(AutoPadToShape, self).__init__() | ||||
|         self.target_shape = target_shape | ||||
|  | ||||
|     def __call__(self, x): | ||||
|     def forward(self, x): | ||||
|         if not torch.is_tensor(x): | ||||
|             x = torch.as_tensor(x) | ||||
|         if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape: | ||||
|         if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape: | ||||
|             return x | ||||
|  | ||||
|         idx = [0] * (len(self.shape) * 2) | ||||
|         for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)): | ||||
|             idx[j] = self.shape[i] - x.shape[i] | ||||
|         idx = [0] * (len(self.target_shape) * 2) | ||||
|         for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)): | ||||
|             idx[j] = self.target_shape[i] - x.shape[i] | ||||
|         x = torch.nn.functional.pad(x, idx) | ||||
|         return x | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'AutoPadTransform({self.shape})' | ||||
|         return f'AutoPadTransform({self.target_shape})' | ||||
|  | ||||
|  | ||||
| class Splitter(nn.Module): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Si11ium
					Si11ium