New Model, Many Changes
@@ -1,3 +1,5 @@
 from typing import List
+from functools import reduce
+
 from abc import ABC
 
@@ -6,7 +8,7 @@ from pathlib import Path
 import torch
+from operator import mul
 from torch import nn
-from torch.nn import functional as F
-
+from torch.nn import functional as F, Unfold
 
 # Utility - Modules
 ###################
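Note on the new imports: `reduce` and `mul` are the standard pairing for computing the number of elements in a shape tuple, and `Unfold` backs the new SlidingWindow module added below. A minimal sketch of the assumed reduce/mul usage (the hunk itself does not show the call site):

from functools import reduce
from operator import mul

shape = (32, 7, 7)               # e.g. (channels, height, width)
n_features = reduce(mul, shape)  # 32 * 7 * 7 = 1568, a flattened feature count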
@@ -38,6 +40,7 @@ try:
             ################################
             self.hparams = hparams
+            self.params = ModelParameters(hparams)
             self.lr = self.params.lr or 1e-4
 
         def size(self):
             return self.shape
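One caveat on `self.lr = self.params.lr or 1e-4`: `or` tests truthiness, so the 1e-4 default kicks in not only when the learning rate is missing but also when it is an explicit zero:

# Truthiness fallback: any falsy lr (None, 0, 0.0) yields the default.
lr = 0.0
print(lr or 1e-4)  # 0.0001, even though lr was set explicitly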
@@ -76,10 +79,10 @@ try:
             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
             self.apply(weight_initializer)
 
-    modules = [LightningBaseModule, nn.Module]
+    module_types = (LightningBaseModule, nn.Module,)
 
 except ImportError:
-    modules = [nn.Module, ]
+    module_types = (nn.Module,)
     pass  # Maybe post a hint to install pytorch-lightning.
 
 
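The `modules` → `module_types` change is more than a rename: `isinstance` accepts a single type or a tuple of types, and passing a list raises `TypeError`, so the old list would have crashed the assertion in the next hunk. A quick illustration:

import torch.nn as nn

layer = nn.Linear(4, 2)
print(isinstance(layer, (nn.Module,)))  # True: a tuple of types is accepted
# isinstance(layer, [nn.Module])        # TypeError: arg 2 must be a type or tuple of types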
@@ -88,7 +91,7 @@ class ShapeMixin:
 
     @property
     def shape(self):
-        assert isinstance(self, modules)
+        assert isinstance(self, module_types)
 
         def get_out_shape(output):
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
@@ -135,6 +138,41 @@ class F_x(ShapeMixin, nn.Module):
         return x
 
 
+class ResidualBlock(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+
+class SlidingWindow(nn.Module):
+    def __init__(self, kernel, stride=1, padding=0, keepdim=False):
+        super(SlidingWindow, self).__init__()
+        self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
+        self.padding = padding
+        self.stride = stride
+        self.keepdim = keepdim
+        self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
+
+    def forward(self, x):
+        tensor = self._unfolder(x)
+        tensor = tensor.transpose(-1, -2)
+        if self.keepdim:
+            shape = *x.shape[:2], -1, *self.kernel
+            tensor = tensor.reshape(shape)
+        return tensor
+
+
 # Utility - Modules
 ###################
 class Flatten(ShapeMixin, nn.Module):
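The three new wrappers follow the usual transformer-style composition pattern: ResidualBlock adds a skip connection around any module, PreNorm applies LayerNorm before it, and SlidingWindow extracts kernel-sized patches via torch.nn.Unfold. A minimal usage sketch, with tensor sizes chosen purely for illustration:

import torch
from torch import nn

# Residual + pre-norm around a plain linear layer; shape is preserved.
block = ResidualBlock(PreNorm(64, nn.Linear(64, 64)))
out = block(torch.randn(8, 16, 64))  # -> (8, 16, 64)

# 3x3 patches over a 32x32 image, one patch per pixel with padding=1.
window = SlidingWindow(kernel=3, stride=1, padding=1, keepdim=True)
patches = window(torch.randn(8, 3, 32, 32))
print(patches.shape)                 # torch.Size([8, 3, 1024, 3, 3])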
@@ -232,14 +270,13 @@ class AutoPadToShape(object):
     def __call__(self, x):
         if not torch.is_tensor(x):
             x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape or x.shape == self.shape:
+        if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
             return x
 
-        for i in range(-1, -len(self.shape), -1):
-            idx = [0] * len(x.shape)
-            idx[i] = self.shape[i] - x.shape[i]
-            idx = tuple(idx)
-            x = torch.nn.functional.pad(x, idx)
+        idx = [0] * (len(self.shape) * 2)
+        for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
+            idx[j] = self.shape[i] - x.shape[i]
+        x = torch.nn.functional.pad(x, idx)
         return x
 
     def __repr__(self):
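The rewritten padding builds a single pad spec for one F.pad call instead of padding dimension by dimension, and the new `x.shape[-len(self.shape):]` comparison tolerates any number of leading batch dimensions. F.pad reads its spec as (left, right) pairs starting from the last dimension; since only the even slots are filled, the deficit is added on one side of each dimension. A worked example with an assumed target shape of (8, 8):

import torch
import torch.nn.functional as F

target = (8, 8)
x = torch.ones(2, 5, 6)        # leading batch dim; maps are 5x6
idx = [0] * (len(target) * 2)  # (left, right) pairs, last dim first
for i, j in zip(range(-1, -(len(target) + 1), -1), range(0, len(idx), 2)):
    idx[j] = target[i] - x.shape[i]  # pad the deficit on one side only
print(idx)                     # [2, 0, 3, 0]
print(F.pad(x, idx).shape)     # torch.Size([2, 8, 8])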