New Model, Many Changes

This commit is contained in:
Si11ium
2020-11-21 09:28:25 +01:00
parent 13812b83b5
commit 14ed4e0117
8 changed files with 127 additions and 102 deletions

View File

@ -1,3 +1,5 @@
from typing import List
from functools import reduce
from abc import ABC
@ -6,7 +8,7 @@ from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as F, Unfold
# Utility - Modules
###################
@ -38,6 +40,7 @@ try:
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
self.lr = self.params.lr or 1e-4
def size(self):
return self.shape
@ -76,10 +79,10 @@ try:
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
modules = [LightningBaseModule, nn.Module]
module_types = (LightningBaseModule, nn.Module,)
except ImportError:
modules = [nn.Module, ]
module_types = (nn.Module,)
pass # Maybe post a hint to install pytorch-lightning.
@ -88,7 +91,7 @@ class ShapeMixin:
@property
def shape(self):
assert isinstance(self, modules)
assert isinstance(self, module_types)
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
@ -135,6 +138,41 @@ class F_x(ShapeMixin, nn.Module):
return x
class ResidualBlock(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SlidingWindow(nn.Module):
def __init__(self, kernel, stride=1, padding=0, keepdim=False):
super(SlidingWindow, self).__init__()
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
self.padding = padding
self.stride = stride
self.keepdim = keepdim
self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
def forward(self, x):
tensor = self._unfolder(x)
tensor = tensor.transpose(-1, -2)
if self.keepdim:
shape = *x.shape[:2], -1, *self.kernel
tensor = tensor.reshape(shape)
return tensor
# Utility - Modules
###################
class Flatten(ShapeMixin, nn.Module):
@ -232,14 +270,13 @@ class AutoPadToShape(object):
def __call__(self, x):
if not torch.is_tensor(x):
x = torch.as_tensor(x)
if x.shape[1:] == self.shape or x.shape == self.shape:
if x.shape[-len(self.shape):] == self.shape or x.shape == self.shape:
return x
for i in range(-1, -len(self.shape), -1):
idx = [0] * len(x.shape)
idx[i] = self.shape[i] - x.shape[i]
idx = tuple(idx)
x = torch.nn.functional.pad(x, idx)
idx = [0] * (len(self.shape) * 2)
for i, j in zip(range(-1, -(len(self.shape)+1), -1), range(0, len(idx), 2)):
idx[j] = self.shape[i] - x.shape[i]
x = torch.nn.functional.pad(x, idx)
return x
def __repr__(self):