Al Lot
@@ -2,14 +2,18 @@
# Full Model Parts
###################
from argparse import Namespace
from functools import reduce
from typing import Union, List, Tuple

import torch
from abc import ABC
from operator import mul
from torch import nn
from torch.utils.data import DataLoader

from .util import ShapeMixin, LightningBaseModule
from .blocks import ConvModule, DeConvModule, LinearModule

from .util import ShapeMixin, LightningBaseModule, Flatten


class AEBaseModule(LightningBaseModule, ABC):
@@ -33,7 +37,7 @@ class AEBaseModule(LightningBaseModule, ABC):
    def encode(self, x):
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        return self.encoder(x).squeeze()
        return self.encoder(x)

    def _find_min_max(self, dataloader):
        encodings = list()
@@ -60,56 +64,61 @@ class AEBaseModule(LightningBaseModule, ABC):
        return self.decode(random_latent_samples).cpu().detach()

    def decode(self, z):
        if len(z.shape) == 1:
            z = z.unsqueeze(0)
        try:
            if len(z.shape) == 1:
                z = z.unsqueeze(0)
        except AttributeError:
            # Does not seem to be a tensor.
            pass
        return self.decoder(z).squeeze()

    def encode_and_restore(self, x):
        x = x.to(self.device)
        x = self.transfer_batch_to_device(x, self.device)
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        z = self.encode(x)
        try:
            z = z.squeeze()
        except AttributeError:
            # Does not seem to be a tensor.
            pass
        x_hat = self.decode(z)

        return Namespace(main_out=x_hat.squeeze(), latent_out=z)


class Generator(nn.Module):
    @property
    def shape(self):
        x = torch.randn(self.lat_dim).unsqueeze(0)
        output = self(x)
        return output.shape[1:]
class Generator(ShapeMixin, nn.Module):

    # noinspection PyUnresolvedReferences
    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
    def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
                 dropout: Union[int, float] = 0, interpolations: List[int] = None,
                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
                 **kwargs):
        super(Generator, self).__init__()
        assert filters, '"Filters" has to be a list of int.'
        assert filters, '"Filters" has to be a list of int.'
        assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'

        self.filters = filters
        self.activation = activation
        self.inner_activation = activation()
        interpolations = interpolations or [2, 2, 2]

        self.in_shape = in_shape
        self.activation = activation()
        self.out_activation = None
        self.lat_dim = lat_dim
        self.dropout = dropout
        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
        self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
        # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])

        self.flat = Flatten(to=re_shape)
        self.flat = Flatten(self.l1.shape, to=re_shape)
        self.de_conv_list = nn.ModuleList()

        last_shape = re_shape
        for conv_filter, conv_kernel in zip(filters, kernels):
            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
        for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
                                                  conv_kernel=conv_kernel,
                                                  conv_padding=conv_kernel-2,
                                                  conv_stride=conv_filter,
                                                  conv_stride=1,
                                                  normalize=use_norm,
                                                  activation=self.activation,
                                                  interpolation_scale=2,
                                                  activation=activation,
                                                  interpolation_scale=interpolation,
                                                  dropout=self.dropout
                                                  )
                                     )
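The reworked __init__ above wires the generator as Linear -> reshape -> a stack of upsampling conv blocks, with one interpolation factor per block (defaulting to [2, 2, 2]) and the filter list applied in reverse. LinearModule, Flatten and DeConvModule live in .blocks/.util and are not shown in this diff; below is a minimal standalone sketch of the same pattern in plain PyTorch, where all module names and sizes are illustrative assumptions rather than the project's API:

import torch
from torch import nn


class TinyGenerator(nn.Module):
    # Illustrative stand-in for the Generator pattern: Linear -> reshape -> (upsample + conv) blocks.
    def __init__(self, in_shape=16, re_shape=(8, 4, 4), filters=(8, 4, 1),
                 kernels=(3, 3, 3), interpolations=(2, 2, 2)):
        super().__init__()
        self.re_shape = re_shape
        flat = re_shape[0] * re_shape[1] * re_shape[2]
        self.l1 = nn.Linear(in_shape, flat)                # plays the role of LinearModule
        blocks = []
        last_channels = re_shape[0]
        for f, k, scale in zip(filters, kernels, interpolations):
            blocks.append(nn.Sequential(                   # plays the role of DeConvModule
                nn.Upsample(scale_factor=scale),           # interpolation_scale
                nn.Conv2d(last_channels, f, kernel_size=k, padding=k - 2, stride=1),
                nn.ReLU(),
            ))
            last_channels = f
        self.blocks = nn.ModuleList(blocks)

    def forward(self, z):
        tensor = torch.relu(self.l1(z))
        tensor = tensor.view(-1, *self.re_shape)           # plays the role of Flatten(to=re_shape)
        for block in self.blocks:
            tensor = block(tensor)
        return tensor


if __name__ == '__main__':
    print(TinyGenerator()(torch.randn(2, 16)).shape)       # torch.Size([2, 1, 32, 32])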
@@ -121,7 +130,7 @@ class Generator(nn.Module):

    def forward(self, z):
        tensor = self.l1(z)
        tensor = self.inner_activation(tensor)
        tensor = self.activation(tensor)
        tensor = self.flat(tensor)

        for de_conv in self.de_conv_list:
@@ -204,8 +213,9 @@ class BaseEncoder(ShapeMixin, nn.Module):
                                             )
                                  )
            last_shape = self.conv_list[-1].shape
        self.last_conv_shape = last_shape

        self.flat = Flatten()
        self.flat = Flatten(self.last_conv_shape)

    def forward(self, x):
        tensor = x
@@ -254,11 +264,11 @@ class VariationalEncoder(BaseEncoder):


class Encoder(BaseEncoder):
    # noinspection PyUnresolvedReferences

    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__(*args, **kwargs)

        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
        self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)

    def forward(self, x):
        tensor = super(Encoder, self).forward(x)

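Both changes in this file replace hard-coded feature counts (reduce(mul, self.conv3.shape)) with shapes queried from the preceding module (self.flat.shape, self.last_conv_shape). The sketch below shows one way a shape-aware Flatten can expose that information; it is an assumption about the interface, since the actual Flatten in .util is not part of this diff:

from functools import reduce
from operator import mul

import torch
from torch import nn


class ShapeAwareFlatten(nn.Module):
    # Sketch: remembers its input shape so downstream layers can be sized from .shape.
    def __init__(self, in_shape, to=(-1,)):
        super().__init__()
        self.in_shape = in_shape
        self.to = to
        # Flattened feature count; lets the caller write nn.Linear(self.flat.shape, ...).
        self.shape = reduce(mul, in_shape) if to == (-1,) else to

    def forward(self, x):
        return x.view(x.size(0), *self.to)


if __name__ == '__main__':
    flat = ShapeAwareFlatten(in_shape=(4, 8, 8))
    head = nn.Linear(flat.shape, 10)                        # 4 * 8 * 8 = 256 input features
    print(head(flat(torch.randn(2, 4, 8, 8))).shape)        # torch.Size([2, 10])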
@@ -6,7 +6,7 @@ from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch import functional as F
from torch.nn import functional as F

import pytorch_lightning as pl

@@ -92,7 +92,14 @@ class ShapeMixin:
            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

        if self.in_shape is not None:
            x = torch.randn(self.in_shape)
            try:
                device = self.device
            except AttributeError:
                try:
                    device = next(self.parameters()).device
                except StopIteration:
                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
            x = torch.randn(self.in_shape, device=device)
            # This is needed for BatchNorm shape checking
            x = torch.stack((x, x))

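The added lines make the ShapeMixin probe device-aware: prefer self.device (available on Lightning modules), fall back to the device of the first parameter, and finally to cuda/cpu for parameter-free modules; the probe tensor is then stacked to a batch of two so BatchNorm layers accept it. A condensed standalone version of that lookup chain, assuming only what is visible in the hunk:

import torch
from torch import nn


def probe_device(module: nn.Module) -> torch.device:
    # Mirrors the fallback chain above: self.device -> first parameter's device -> cuda/cpu.
    try:
        return module.device                        # e.g. a pl.LightningModule exposes .device
    except AttributeError:
        try:
            return next(module.parameters()).device
        except StopIteration:                       # module has no parameters at all
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == '__main__':
    layer = nn.Linear(4, 2)
    x = torch.randn((4,), device=probe_device(layer))
    x = torch.stack((x, x))                         # batch of two keeps BatchNorm happy
    print(nn.Sequential(layer, nn.BatchNorm1d(2))(x).shape)  # torch.Size([2, 2])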
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
    def __init__(self, in_shape, n, dim=-1):
        super(Splitter, self).__init__()

        self.in_shape = in_shape
        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
        self.n = n
        self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
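The Splitter change normalizes an int in_shape to a tuple and converts a negative dim into a positive axis index over that (batch-less) shape, so dim=-1 with a three-axis in_shape becomes axis 2. A small worked example of the index arithmetic, plus one plausible way such a module could split using torch.chunk (an assumption; the actual forward is not part of this diff):

import torch

in_shape = (3, 32, 32)   # unbatched sample shape
n, dim = 4, -1

# Same normalization as in the hunk above.
in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
dim = dim if dim > 0 else len(in_shape) - abs(dim)    # -1 -> 2 (last of the three axes)

x = torch.randn(8, *in_shape)                          # batched input
parts = torch.chunk(x, n, dim=dim + 1)                 # +1 to skip the batch axis
print(dim, [tuple(p.shape) for p in parts])            # 2, four chunks of (8, 3, 32, 8)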