SubSpectral and Lightning 0.9 Update
This commit is contained in:
@@ -130,8 +130,9 @@ class DeConvModule(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
|
||||
dropout: Union[int, float] = 0, autopad=0,
|
||||
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
|
||||
bias=True, norm=False):
|
||||
bias=True, norm=False, **kwargs):
|
||||
super(DeConvModule, self).__init__()
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||
self.padding = conv_padding
|
||||
self.conv_kernel = conv_kernel
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import ReLU
|
||||
|
||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
||||
try:
|
||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
||||
except ImportError:
|
||||
print('Install torch-geometric to use this package.')
|
||||
|
||||
|
||||
class SAModule(torch.nn.Module):
|
||||
|
@@ -1,10 +1,77 @@
|
||||
#
|
||||
# Full Model Parts
|
||||
###################
|
||||
import torch
|
||||
from torch import nn
|
||||
from argparse import Namespace
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from .util import ShapeMixin
|
||||
import torch
|
||||
from abc import ABC
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .util import ShapeMixin, LightningBaseModule
|
||||
|
||||
|
||||
class AEBaseModule(LightningBaseModule, ABC):
|
||||
|
||||
def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
|
||||
lat_min: Union[Tuple, List, None] = None,
|
||||
lat_max: Union[Tuple, List, None] = None):
|
||||
|
||||
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
|
||||
|
||||
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
|
||||
# assert not any([x is None for x in min_max])
|
||||
lat_min = torch.as_tensor(lat_min or min_max[0])
|
||||
lat_max = lat_max or min_max[1]
|
||||
|
||||
random_z = torch.rand((1, self.lat_dim))
|
||||
random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min)
|
||||
|
||||
return self.decoder(random_z).squeeze()
|
||||
|
||||
def encode(self, x):
|
||||
if len(x.shape) == 3:
|
||||
x = x.unsqueeze(0)
|
||||
return self.encoder(x).squeeze()
|
||||
|
||||
def _find_min_max(self, dataloader):
|
||||
encodings = list()
|
||||
for batch in dataloader:
|
||||
encodings.append(self.encode(batch))
|
||||
encodings = torch.cat(encodings, dim=0)
|
||||
min_lat = encodings.min(dim=1)
|
||||
max_lat = encodings.max(dim=1)
|
||||
return min_lat, max_lat
|
||||
|
||||
def decode_lat_evenly(self, n: int,
|
||||
dataloader: Union[None, str, DataLoader] = None,
|
||||
lat_min: Union[Tuple, List, None] = None,
|
||||
lat_max: Union[Tuple, List, None] = None):
|
||||
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
|
||||
|
||||
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
|
||||
|
||||
lat_min = lat_min or min_max[0]
|
||||
lat_max = lat_max or min_max[1]
|
||||
|
||||
random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
|
||||
for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
|
||||
return self.decode(random_latent_samples).cpu().detach()
|
||||
|
||||
def decode(self, z):
|
||||
if len(z.shape) == 1:
|
||||
z = z.unsqueeze(0)
|
||||
return self.decoder(z).squeeze()
|
||||
|
||||
def encode_and_restore(self, x):
|
||||
x = x.to(self.device)
|
||||
if len(x.shape) == 3:
|
||||
x = x.unsqueeze(0)
|
||||
z = self.encode(x)
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return Namespace(main_out=x_hat.squeeze(), latent_out=z)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@@ -16,9 +83,12 @@ class Generator(nn.Module):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
|
||||
filters: List[int] = None, activation=nn.ReLU):
|
||||
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
|
||||
super(Generator, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int len 3'
|
||||
assert filters, '"Filters" has to be a list of int.'
|
||||
assert filters, '"Filters" has to be a list of int.'
|
||||
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
|
||||
|
||||
self.filters = filters
|
||||
self.activation = activation
|
||||
self.inner_activation = activation()
|
||||
@@ -29,52 +99,35 @@ class Generator(nn.Module):
|
||||
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
|
||||
|
||||
self.flat = Flatten(to=re_shape)
|
||||
self.de_conv_list = nn.ModuleList()
|
||||
|
||||
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
|
||||
conv_kernel=5,
|
||||
conv_padding=2,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
last_shape = re_shape
|
||||
for conv_filter, conv_kernel in zip(filters, kernels):
|
||||
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
|
||||
conv_kernel=conv_kernel,
|
||||
conv_padding=conv_kernel-2,
|
||||
conv_stride=conv_filter,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
)
|
||||
last_shape = self.de_conv_list[-1].shape
|
||||
|
||||
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
|
||||
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
|
||||
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
# normalize=norm,
|
||||
activation=self.out_activation
|
||||
)
|
||||
self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
|
||||
conv_padding=1, activation=self.out_activation
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
tensor = self.l1(z)
|
||||
tensor = self.inner_activation(tensor)
|
||||
tensor = self.flat(tensor)
|
||||
tensor = self.deconv1(tensor)
|
||||
tensor = self.deconv2(tensor)
|
||||
tensor = self.deconv3(tensor)
|
||||
tensor = self.deconv4(tensor)
|
||||
|
||||
for de_conv in self.de_conv_list:
|
||||
tensor = de_conv(tensor)
|
||||
|
||||
tensor = self.de_conv_out(tensor)
|
||||
return tensor
|
||||
|
||||
def size(self):
|
||||
@@ -119,12 +172,14 @@ class BaseEncoder(ShapeMixin, nn.Module):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
||||
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
|
||||
filters: List[int] = None):
|
||||
filters: List[int] = None, kernels: List[int] = None, **kwargs):
|
||||
super(BaseEncoder, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int len 3'
|
||||
assert filters, '"Filters" has to be a list of int'
|
||||
assert kernels, '"Kernels" has to be a list of int'
|
||||
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
||||
|
||||
# Optional Padding for odd image-sizes
|
||||
# Obsolet, already Done by autopadding module on incoming tensors
|
||||
# Obsolet, cdan be done by autopadding module on incoming tensors
|
||||
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
|
||||
|
||||
# Parameters
|
||||
@@ -133,43 +188,29 @@ class BaseEncoder(ShapeMixin, nn.Module):
|
||||
self.use_bias = use_bias
|
||||
self.latent_activation = latent_activation() if latent_activation else None
|
||||
|
||||
self.conv_list = nn.ModuleList()
|
||||
|
||||
# Modules
|
||||
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
|
||||
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
|
||||
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
|
||||
conv_kernel=5,
|
||||
conv_padding=2,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
last_shape = self.in_shape
|
||||
for conv_filter, conv_kernel in zip(filters, kernels):
|
||||
self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
|
||||
conv_kernel=conv_kernel,
|
||||
conv_padding=conv_kernel-2,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
)
|
||||
last_shape = self.conv_list[-1].shape
|
||||
|
||||
self.flat = Flatten()
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.conv1(x)
|
||||
tensor = self.conv2(tensor)
|
||||
tensor = self.conv3(tensor)
|
||||
tensor = x
|
||||
for conv in self.conv_list:
|
||||
tensor = conv(tensor)
|
||||
tensor = self.flat(tensor)
|
||||
return tensor
|
||||
|
||||
|
@@ -1,7 +1,10 @@
|
||||
from functools import reduce
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from operator import mul
|
||||
from torch import nn
|
||||
from torch import functional as F
|
||||
|
||||
@@ -102,6 +105,14 @@ class ShapeMixin:
|
||||
else:
|
||||
return -1
|
||||
|
||||
@property
|
||||
def flat_shape(self):
|
||||
shape = self.shape
|
||||
try:
|
||||
return reduce(mul, shape)
|
||||
except TypeError:
|
||||
return shape
|
||||
|
||||
|
||||
class F_x(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape):
|
||||
@@ -175,7 +186,7 @@ class WeightInit:
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class Filter(nn.Module):
|
||||
class Filter(nn.Module, ShapeMixin):
|
||||
|
||||
def __init__(self, in_shape, pos, dim=-1):
|
||||
super(Filter, self).__init__()
|
||||
@@ -210,11 +221,15 @@ class AutoPadToShape(object):
|
||||
def __call__(self, x):
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.as_tensor(x)
|
||||
if x.shape[1:] == self.shape:
|
||||
if x.shape[1:] == self.shape or x.shape == self.shape:
|
||||
return x
|
||||
embedding = torch.zeros((x.shape[0], *self.shape))
|
||||
embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
|
||||
return embedding
|
||||
|
||||
for i in range(-1, -len(self.shape), -1):
|
||||
idx = [0] * len(x.shape)
|
||||
idx[i] = self.shape[i] - x.shape[i]
|
||||
idx = tuple(idx)
|
||||
x = torch.nn.functional.pad(x, idx)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f'AutoPadTransform({self.shape})'
|
||||
@@ -233,9 +248,9 @@ class Splitter(nn.Module):
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Splitter, self).__init__()
|
||||
|
||||
self.n = n
|
||||
self.dim = dim
|
||||
self.in_shape = in_shape
|
||||
self.n = n
|
||||
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
|
||||
|
||||
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
|
||||
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
|
||||
@@ -243,22 +258,23 @@ class Splitter(nn.Module):
|
||||
self.autopad = AutoPadToShape(self._out_shape)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x.transpose(0, self.dim)
|
||||
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
|
||||
x = x.transpose(0, dim)
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_dim_size
|
||||
end = (block_idx + 1) * self.new_dim_size
|
||||
block = self.autopad(x[:, :, start:end, :])
|
||||
|
||||
n_blocks.append(block.transpose(0, self.dim))
|
||||
block = x[start:end].transpose(0, dim)
|
||||
block = self.autopad(block)
|
||||
n_blocks.append(block)
|
||||
return n_blocks
|
||||
|
||||
|
||||
class Merger(nn.Module):
|
||||
class Merger(nn.Module, ShapeMixin):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
y = self.forward([torch.randn(self.in_shape)])
|
||||
y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
|
||||
return y.shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
|
Reference in New Issue
Block a user