SubSpectral and Lightning 0.9 Update
This commit is contained in:
@ -1,7 +1,10 @@
|
||||
from functools import reduce
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from operator import mul
|
||||
from torch import nn
|
||||
from torch import functional as F
|
||||
|
||||
@ -102,6 +105,14 @@ class ShapeMixin:
|
||||
else:
|
||||
return -1
|
||||
|
||||
@property
|
||||
def flat_shape(self):
|
||||
shape = self.shape
|
||||
try:
|
||||
return reduce(mul, shape)
|
||||
except TypeError:
|
||||
return shape
|
||||
|
||||
|
||||
class F_x(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape):
|
||||
@ -175,7 +186,7 @@ class WeightInit:
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class Filter(nn.Module):
|
||||
class Filter(nn.Module, ShapeMixin):
|
||||
|
||||
def __init__(self, in_shape, pos, dim=-1):
|
||||
super(Filter, self).__init__()
|
||||
@ -210,11 +221,15 @@ class AutoPadToShape(object):
|
||||
def __call__(self, x):
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.as_tensor(x)
|
||||
if x.shape[1:] == self.shape:
|
||||
if x.shape[1:] == self.shape or x.shape == self.shape:
|
||||
return x
|
||||
embedding = torch.zeros((x.shape[0], *self.shape))
|
||||
embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
|
||||
return embedding
|
||||
|
||||
for i in range(-1, -len(self.shape), -1):
|
||||
idx = [0] * len(x.shape)
|
||||
idx[i] = self.shape[i] - x.shape[i]
|
||||
idx = tuple(idx)
|
||||
x = torch.nn.functional.pad(x, idx)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f'AutoPadTransform({self.shape})'
|
||||
@ -233,9 +248,9 @@ class Splitter(nn.Module):
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Splitter, self).__init__()
|
||||
|
||||
self.n = n
|
||||
self.dim = dim
|
||||
self.in_shape = in_shape
|
||||
self.n = n
|
||||
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
|
||||
|
||||
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
|
||||
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
|
||||
@ -243,22 +258,23 @@ class Splitter(nn.Module):
|
||||
self.autopad = AutoPadToShape(self._out_shape)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x.transpose(0, self.dim)
|
||||
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
|
||||
x = x.transpose(0, dim)
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_dim_size
|
||||
end = (block_idx + 1) * self.new_dim_size
|
||||
block = self.autopad(x[:, :, start:end, :])
|
||||
|
||||
n_blocks.append(block.transpose(0, self.dim))
|
||||
block = x[start:end].transpose(0, dim)
|
||||
block = self.autopad(block)
|
||||
n_blocks.append(block)
|
||||
return n_blocks
|
||||
|
||||
|
||||
class Merger(nn.Module):
|
||||
class Merger(nn.Module, ShapeMixin):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
y = self.forward([torch.randn(self.in_shape)])
|
||||
y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
|
||||
return y.shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
|
Reference in New Issue
Block a user