SubSpectral and Lightning 0.9 Update

This commit is contained in:
Si11ium
2020-09-25 15:35:15 +02:00
parent 6bc9447ce1
commit 5848b528f0
13 changed files with 197 additions and 630 deletions

View File

@@ -130,8 +130,9 @@ class DeConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=0,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
bias=True, norm=False):
bias=True, norm=False, **kwargs):
super(DeConvModule, self).__init__()
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.padding = conv_padding
self.conv_kernel = conv_kernel

View File

@@ -1,8 +1,10 @@
import torch
from torch import nn
from torch.nn import ReLU
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
try:
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
except ImportError:
print('Install torch-geometric to use this package.')
class SAModule(torch.nn.Module):

View File

@@ -1,10 +1,77 @@
#
# Full Model Parts
###################
import torch
from torch import nn
from argparse import Namespace
from typing import Union, List, Tuple
from .util import ShapeMixin
import torch
from abc import ABC
from torch import nn
from torch.utils.data import DataLoader
from .util import ShapeMixin, LightningBaseModule
class AEBaseModule(LightningBaseModule, ABC):
    """Shared latent-space helpers for autoencoder-style models.

    The methods below call ``self.encoder`` and ``self.decoder`` and read
    ``self.device``; none of these is defined in this class, so they are
    expected to be provided by subclasses or by the base module.
    The latent size is read from ``self.lat_dim`` and, when that attribute is
    missing, from ``self.params.lat_dim``.
    """

    def _latent_size(self):
        """Return the number of latent dimensions of the model."""
        lat_dim = getattr(self, 'lat_dim', None)
        if lat_dim is None:
            lat_dim = self.params.lat_dim
        return lat_dim

    def _resolve_latent_bounds(self, dataloader, lat_min, lat_max):
        """Return ``(lat_min, lat_max)`` as detached float tensors on the CPU.

        Bounds that were not passed explicitly are taken from the encodings of
        ``dataloader``.

        Raises:
            AssertionError: if neither or both of "a dataloader" and
                "both explicit bounds" were supplied.
        """
        has_loader = dataloader is not None
        has_bounds = lat_min is not None and lat_max is not None
        assert has_loader ^ has_bounds, 'Decide wether to give min, max or a dataloader, not both.'
        if has_loader:
            found_min, found_max = self._find_min_max(dataloader)
            lat_min = found_min if lat_min is None else lat_min
            lat_max = found_max if lat_max is None else lat_max
        # Tuples / lists of plain numbers are accepted as well, hence as_tensor.
        lat_min = torch.as_tensor(lat_min, dtype=torch.float).detach().cpu()
        lat_max = torch.as_tensor(lat_max, dtype=torch.float).detach().cpu()
        return lat_min, lat_max

    def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
                              lat_min: Union[Tuple, List, None] = None,
                              lat_max: Union[Tuple, List, None] = None):
        """Decode one latent vector drawn uniformly from ``[lat_min, lat_max]``.

        Args:
            dataloader: source of samples whose encodings define the bounds.
            lat_min: explicit lower bound per latent dimension.
            lat_max: explicit upper bound per latent dimension.

        Returns:
            The squeezed output of ``self.decoder`` for the sampled vector.
        """
        lat_min, lat_max = self._resolve_latent_bounds(dataloader, lat_min, lat_max)
        random_z = torch.rand((1, self._latent_size()))
        # Affine map from [0, 1) onto [lat_min, lat_max); unlike
        # ``rand * (|min| + max) - |min|`` this also holds for positive minima.
        random_z = lat_min + random_z * (lat_max - lat_min)
        return self.decoder(random_z).squeeze()

    def encode(self, x):
        """Run ``self.encoder`` on ``x``; a 3-dimensional input gets a leading axis first."""
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        return self.encoder(x).squeeze()

    def _find_min_max(self, dataloader):
        """Return the per-latent-dimension minimum and maximum over all encodings.

        Raises:
            ValueError: if ``dataloader`` yields no batch.
        """
        encodings = []
        for batch in dataloader:
            z = self.encode(batch)
            # ``encode`` squeezes, so a batch of one comes back 1-dimensional;
            # restore the sample axis so all batches concatenate along dim 0.
            if z.dim() == 1:
                z = z.unsqueeze(0)
            encodings.append(z)
        if not encodings:
            raise ValueError('The dataloader did not yield any batch to encode.')
        encodings = torch.cat(encodings, dim=0)
        # Reduce over the sample axis; ``.values`` drops the arg-indices that
        # ``Tensor.min(dim=...)`` / ``Tensor.max(dim=...)`` return alongside.
        min_lat = encodings.min(dim=0).values
        max_lat = encodings.max(dim=0).values
        return min_lat, max_lat

    def decode_lat_evenly(self, n: int,
                          dataloader: Union[None, str, DataLoader] = None,
                          lat_min: Union[Tuple, List, None] = None,
                          lat_max: Union[Tuple, List, None] = None):
        """Decode ``n`` latent vectors spaced evenly between the bounds.

        Row ``k`` of the decoded batch holds the ``k``-th linspace step of
        every latent dimension.

        Args:
            n: number of steps between lower and upper bound.
            dataloader: source of samples whose encodings define the bounds.
            lat_min: explicit lower bound per latent dimension.
            lat_max: explicit upper bound per latent dimension.

        Returns:
            The decoded samples, detached and moved to the CPU.
        """
        lat_min, lat_max = self._resolve_latent_bounds(dataloader, lat_min, lat_max)
        latent_samples = torch.stack(
            [torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
             for i in range(self._latent_size())],
            dim=-1).cpu().detach()
        return self.decode(latent_samples).cpu().detach()

    def decode(self, z):
        """Run ``self.decoder`` on ``z``; a 1-dimensional ``z`` gets a leading axis first."""
        if len(z.shape) == 1:
            z = z.unsqueeze(0)
        return self.decoder(z).squeeze()

    def encode_and_restore(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            Namespace with ``main_out`` (the reconstruction) and ``latent_out``
            (the encoding).
        """
        x = x.to(self.device)
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        z = self.encode(x)
        x_hat = self.decode(z)
        return Namespace(main_out=x_hat.squeeze(), latent_out=z)
class Generator(nn.Module):
@@ -16,9 +83,12 @@ class Generator(nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, activation=nn.ReLU):
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
super(Generator, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
assert filters, '"Filters" has to be a list of int.'
assert filters, '"Filters" has to be a list of int.'
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
@@ -29,52 +99,35 @@ class Generator(nn.Module):
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.de_conv_list = nn.ModuleList()
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
last_shape = re_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=conv_filter,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
)
last_shape = self.de_conv_list[-1].shape
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
conv_kernel=3,
conv_padding=1,
# normalize=norm,
activation=self.out_activation
)
self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
conv_padding=1, activation=self.out_activation
)
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.flat(tensor)
tensor = self.deconv1(tensor)
tensor = self.deconv2(tensor)
tensor = self.deconv3(tensor)
tensor = self.deconv4(tensor)
for de_conv in self.de_conv_list:
tensor = de_conv(tensor)
tensor = self.de_conv_out(tensor)
return tensor
def size(self):
@@ -119,12 +172,14 @@ class BaseEncoder(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None):
filters: List[int] = None, kernels: List[int] = None, **kwargs):
super(BaseEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
assert filters, '"Filters" has to be a list of int'
assert kernels, '"Kernels" has to be a list of int'
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
# Optional Padding for odd image-sizes
# Obsolet, already Done by autopadding module on incoming tensors
# Obsolet, cdan be done by autopadding module on incoming tensors
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
# Parameters
@@ -133,43 +188,29 @@ class BaseEncoder(ShapeMixin, nn.Module):
self.use_bias = use_bias
self.latent_activation = latent_activation() if latent_activation else None
self.conv_list = nn.ModuleList()
# Modules
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
last_shape = self.in_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
)
last_shape = self.conv_list[-1].shape
self.flat = Flatten()
def forward(self, x):
tensor = self.conv1(x)
tensor = self.conv2(tensor)
tensor = self.conv3(tensor)
tensor = x
for conv in self.conv_list:
tensor = conv(tensor)
tensor = self.flat(tensor)
return tensor

View File

@@ -1,7 +1,10 @@
from functools import reduce
from abc import ABC
from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch import functional as F
@@ -102,6 +105,14 @@ class ShapeMixin:
else:
return -1
@property
def flat_shape(self):
    """Product of all entries of ``self.shape``.

    If the product cannot be formed because ``reduce`` raises ``TypeError``
    (for instance when ``self.shape`` is not iterable), ``self.shape`` itself
    is returned unchanged.
    """
    dims = self.shape
    try:
        n_elements = reduce(mul, dims)
    except TypeError:
        # Nothing to multiply: hand the raw shape back.
        n_elements = dims
    return n_elements
class F_x(ShapeMixin, nn.Module):
def __init__(self, in_shape):
@@ -175,7 +186,7 @@ class WeightInit:
m.bias.data.fill_(0.01)
class Filter(nn.Module):
class Filter(nn.Module, ShapeMixin):
def __init__(self, in_shape, pos, dim=-1):
super(Filter, self).__init__()
@@ -210,11 +221,15 @@ class AutoPadToShape(object):
def __call__(self, x):
    """Zero-pad ``x`` at the end of its trailing dimensions up to ``self.shape``.

    Args:
        x: tensor, or anything ``torch.as_tensor`` accepts. Its trailing
            dimensions are matched against ``self.shape``; an extra leading
            (batch) dimension is left untouched.

    Returns:
        ``x`` itself when it already has the target shape (with or without a
        leading batch dimension), otherwise a zero-padded copy.

    Raises:
        ValueError: if a dimension of ``x`` is larger than its target size.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x)
    target = tuple(self.shape)
    if tuple(x.shape[1:]) == target or tuple(x.shape) == target:
        return x
    # ``torch.nn.functional.pad`` takes (left, right) pairs ordered from the
    # LAST dimension backwards, so the pad list is built in that order.
    pad = []
    for offset in range(1, min(len(target), x.ndim) + 1):
        missing = target[-offset] - x.shape[-offset]
        if missing < 0:
            raise ValueError(f'Input of shape {tuple(x.shape)} exceeds the target shape {target}.')
        pad.extend((0, missing))
    return torch.nn.functional.pad(x, tuple(pad))
def __repr__(self):
return f'AutoPadTransform({self.shape})'
@@ -233,9 +248,9 @@ class Splitter(nn.Module):
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
self.n = n
self.dim = dim
self.in_shape = in_shape
self.n = n
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
@@ -243,22 +258,23 @@ class Splitter(nn.Module):
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
x = x.transpose(0, self.dim)
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
x = x.transpose(0, dim)
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_dim_size
end = (block_idx + 1) * self.new_dim_size
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block.transpose(0, self.dim))
block = x[start:end].transpose(0, dim)
block = self.autopad(block)
n_blocks.append(block)
return n_blocks
class Merger(nn.Module):
class Merger(nn.Module, ShapeMixin):
@property
def shape(self):
y = self.forward([torch.randn(self.in_shape)])
y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
return y.shape
def __init__(self, in_shape, n, dim=-1):