This commit is contained in:
Si11ium
2020-10-07 15:21:45 +02:00
parent 5848b528f0
commit f296ba78b9
6 changed files with 78 additions and 39 deletions

View File

@ -2,14 +2,18 @@
# Full Model Parts
###################
from argparse import Namespace
from functools import reduce
from typing import Union, List, Tuple
import torch
from abc import ABC
from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from .util import ShapeMixin, LightningBaseModule
from .blocks import ConvModule, DeConvModule, LinearModule
from .util import ShapeMixin, LightningBaseModule, Flatten
class AEBaseModule(LightningBaseModule, ABC):
@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
def encode(self, x):
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.encoder(x).squeeze()
return self.encoder(x)
def _find_min_max(self, dataloader):
encodings = list()
@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
return self.decode(random_latent_samples).cpu().detach()
def decode(self, z):
if len(z.shape) == 1:
z = z.unsqueeze(0)
try:
if len(z.shape) == 1:
z = z.unsqueeze(0)
except AttributeError:
# Does not seem to be a tensor.
pass
return self.decoder(z).squeeze()
def encode_and_restore(self, x):
x = x.to(self.device)
x = self.transfer_batch_to_device(x, self.device)
if len(x.shape) == 3:
x = x.unsqueeze(0)
z = self.encode(x)
try:
z = z.squeeze()
except AttributeError:
# Does not seem to be a tensor.
pass
x_hat = self.decode(z)
return Namespace(main_out=x_hat.squeeze(), latent_out=z)
class Generator(nn.Module):
@property
def shape(self):
x = torch.randn(self.lat_dim).unsqueeze(0)
output = self(x)
return output.shape[1:]
class Generator(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
dropout: Union[int, float] = 0, interpolations: List[int] = None,
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
**kwargs):
super(Generator, self).__init__()
assert filters, '"Filters" has to be a list of int.'
assert filters, '"Filters" has to be a list of int.'
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
interpolations = interpolations or [2, 2, 2]
self.in_shape = in_shape
self.activation = activation()
self.out_activation = None
self.lat_dim = lat_dim
self.dropout = dropout
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.flat = Flatten(self.l1.shape, to=re_shape)
self.de_conv_list = nn.ModuleList()
last_shape = re_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=conv_filter,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
activation=activation,
interpolation_scale=interpolation,
dropout=self.dropout
)
)
@ -121,7 +130,7 @@ class Generator(nn.Module):
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.activation(tensor)
tensor = self.flat(tensor)
for de_conv in self.de_conv_list:
@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
)
)
last_shape = self.conv_list[-1].shape
self.last_conv_shape = last_shape
self.flat = Flatten()
self.flat = Flatten(self.last_conv_shape)
def forward(self, x):
tensor = x
@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):
class Encoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)

View File

@ -6,7 +6,7 @@ from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch import functional as F
from torch.nn import functional as F
import pytorch_lightning as pl
@ -92,7 +92,14 @@ class ShapeMixin:
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
x = torch.randn(self.in_shape)
try:
device = self.device
except AttributeError:
try:
device = next(self.parameters()).device
except StopIteration:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(self.in_shape, device=device)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
@ -248,7 +255,7 @@ class Splitter(nn.Module):
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
self.in_shape = in_shape
self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
self.n = n
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)