SubSpectral and Lightning 0.9 Update

Si11ium
2020-09-25 15:35:15 +02:00
parent 6bc9447ce1
commit 5848b528f0
13 changed files with 197 additions and 630 deletions


@@ -1,7 +1,10 @@
+from functools import reduce
from abc import ABC
from pathlib import Path
import torch
+from operator import mul
from torch import nn
from torch import functional as F
@@ -102,6 +105,14 @@ class ShapeMixin:
        else:
            return -1

+    @property
+    def flat_shape(self):
+        shape = self.shape
+        try:
+            return reduce(mul, shape)
+        except TypeError:
+            return shape


class F_x(ShapeMixin, nn.Module):
    def __init__(self, in_shape):
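The new flat_shape property folds a shape tuple into the flat element count a Linear or flatten layer would expect, and falls back to the value itself when shape is already a bare int. A minimal standalone sketch of that behavior (the function name here is illustrative):

    from functools import reduce
    from operator import mul

    def flat_shape(shape):
        try:
            # multiply all entries, e.g. (3, 32, 32) -> 3072
            return reduce(mul, shape)
        except TypeError:
            # reduce() rejects a plain int, so a 1D shape passes through unchanged
            return shape

    assert flat_shape((3, 32, 32)) == 3072
    assert flat_shape(128) == 128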
@@ -175,7 +186,7 @@ class WeightInit:
            m.bias.data.fill_(0.01)


-class Filter(nn.Module):
+class Filter(nn.Module, ShapeMixin):

    def __init__(self, in_shape, pos, dim=-1):
        super(Filter, self).__init__()
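Mixing ShapeMixin into Filter (and into Merger below) gives those modules the derived shape properties without a dummy forward pass. A toy version of the pattern, with a hypothetical, reduced stand-in for ShapeMixin since its full definition sits outside this hunk:

    import torch
    from torch import nn

    class TinyShapeMixin:
        # simplified stand-in for the repo's ShapeMixin:
        # expose whatever in_shape the module recorded in __init__
        @property
        def shape(self):
            return getattr(self, 'in_shape', -1)

    class Passthrough(TinyShapeMixin, nn.Module):
        def __init__(self, in_shape):
            super().__init__()
            self.in_shape = in_shape

        def forward(self, x):
            return x

    print(Passthrough((3, 32, 32)).shape)  # (3, 32, 32)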
@@ -210,11 +221,15 @@ class AutoPadToShape(object):
    def __call__(self, x):
        if not torch.is_tensor(x):
            x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape:
+        if x.shape[1:] == self.shape or x.shape == self.shape:
            return x
-        embedding = torch.zeros((x.shape[0], *self.shape))
-        embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
-        return embedding
+        for i in range(-1, -len(self.shape), -1):
+            idx = [0] * len(x.shape)
+            idx[i] = self.shape[i] - x.shape[i]
+            idx = tuple(idx)
+            x = torch.nn.functional.pad(x, idx)
+        return x

    def __repr__(self):
        return f'AutoPadTransform({self.shape})'
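The rewritten __call__ pads via torch.nn.functional.pad instead of allocating a zero tensor and copying into it, and the extra x.shape == self.shape check lets unbatched tensors pass straight through. For reference, F.pad reads its pad argument as (left, right) pairs starting from the last dimension, a convention worth a quick standalone check:

    import torch
    import torch.nn.functional as F

    x = torch.ones(2, 3, 5)
    # (0, 3) grows dim -1 from 5 to 8; (0, 1) grows dim -2 from 3 to 4;
    # dimensions not covered by a pair are left untouched
    padded = F.pad(x, (0, 3, 0, 1))
    assert padded.shape == torch.Size((2, 4, 8))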
@@ -233,9 +248,9 @@ class Splitter(nn.Module):
    def __init__(self, in_shape, n, dim=-1):
        super(Splitter, self).__init__()

-        self.n = n
-        self.dim = dim
        self.in_shape = in_shape
+        self.n = n
+        self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)

        self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
        self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
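Splitter now resolves a negative dim against in_shape once, up front, and sizes the blocks by ceiling division so a short last block can be padded up to the common size. Worked through with concrete, illustrative numbers:

    n, dim = 4, -1
    in_shape = (3, 30, 9)                                 # unbatched (C, H, W)
    dim = dim if dim > 0 else len(in_shape) - abs(dim)    # -1 resolves to 2
    # ceiling division: 9 entries over 4 blocks -> blocks of 3, last one short
    new_dim_size = (in_shape[dim] // n) + (1 if in_shape[dim] % n != 0 else 0)
    out_shape = tuple(x if dim != i else new_dim_size for i, x in enumerate(in_shape))
    assert (dim, new_dim_size, out_shape) == (2, 3, (3, 30, 3))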
@@ -243,22 +258,23 @@ class Splitter(nn.Module):
        self.autopad = AutoPadToShape(self._out_shape)

    def forward(self, x: torch.Tensor):
-        x = x.transpose(0, self.dim)
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
+        x = x.transpose(0, dim)
        n_blocks = list()
        for block_idx in range(self.n):
            start = block_idx * self.new_dim_size
            end = (block_idx + 1) * self.new_dim_size
-            block = self.autopad(x[:, :, start:end, :])
-            n_blocks.append(block.transpose(0, self.dim))
+            block = x[start:end].transpose(0, dim)
+            block = self.autopad(block)
+            n_blocks.append(block)
        return n_blocks
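The new forward shifts the split axis by one when a batch dimension is present (x.ndim is one larger than len(self.in_shape)), then transposes that axis to the front so plain start:end slicing works at any rank, where the old version hard-coded a 4D [:, :, start:end, :] slice. The slicing core, sketched without the autopadding step:

    import torch

    def split_blocks(x, n, dim, new_dim_size):
        x = x.transpose(0, dim)            # bring the split axis to the front
        blocks = []
        for block_idx in range(n):
            start = block_idx * new_dim_size
            end = (block_idx + 1) * new_dim_size
            blocks.append(x[start:end].transpose(0, dim))  # restore axis order
        return blocks

    blocks = split_blocks(torch.randn(3, 30, 8), n=4, dim=2, new_dim_size=2)
    assert len(blocks) == 4 and blocks[0].shape == (3, 30, 2)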
-class Merger(nn.Module):
+class Merger(nn.Module, ShapeMixin):

    @property
    def shape(self):
-        y = self.forward([torch.randn(self.in_shape)])
+        y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
        return y.shape

    def __init__(self, in_shape, n, dim=-1):
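Probing shape with a single dummy tensor was wrong for a module that merges n inputs; the fixed property builds one random tensor per expected branch. A hypothetical, concatenation-based stand-in (Merger's real forward is outside this hunk) showing the probe at work:

    import torch
    from torch import nn

    class CatMerger(nn.Module):
        # illustrative stand-in, not the repo's Merger
        def __init__(self, in_shape, n, dim=-1):
            super().__init__()
            self.in_shape, self.n, self.dim = in_shape, n, dim

        @property
        def shape(self):
            # one dummy tensor per expected input branch
            y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
            return y.shape

        def forward(self, tensors):
            return torch.cat(tensors, dim=self.dim)

    print(CatMerger((3, 30, 2), n=4).shape)  # torch.Size([3, 30, 8])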