hom_traj_gen/lib/models/blocks.py

from abc import ABC
from functools import reduce
from operator import mul
from pathlib import Path
from typing import List, Union

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

# Utility - Modules
###################
from torch.utils.data import DataLoader

from dataset.dataset import TrajData
class Flatten(nn.Module):
    def __init__(self, to=(-1, )):
        super(Flatten, self).__init__()
        # Store the target shape under a private name; assigning it to "self.to"
        # would shadow nn.Module.to().
        self.to_shape = to

    def forward(self, x):
        return x.view(x.size(0), *self.to_shape)
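
# Hedged usage sketch for Flatten (the helper name and shapes below are illustrative
# only, not part of the original module):
def _example_flatten():
    flat = Flatten()
    out = flat(torch.randn(4, 1, 8, 8))
    return out.shape  # torch.Size([4, 64]): batch dimension kept, the rest collapsed
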
class Interpolate(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.scale_factor = scale_factor
self.align_corners = align_corners
self.mode = mode
def forward(self, x):
x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
mode=self.mode, align_corners=self.align_corners)
return x
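
# Hedged usage sketch for Interpolate: nearest-neighbour upsampling by a factor of 2.
# The input shape is an assumed toy size, not taken from the project configuration.
def _example_interpolate():
    up = Interpolate(scale_factor=2)
    out = up(torch.randn(1, 3, 8, 8))
    return out.shape  # torch.Size([1, 3, 16, 16])
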
class AutoPad(nn.Module):
def __init__(self, interpolations=3, base=2):
super(AutoPad, self).__init__()
self.fct = base ** interpolations
    def forward(self, x):
        # Pad width and height up to the next multiple of self.fct so that
        # `interpolations` successive down-/up-sampling steps stay shape-consistent.
        x = F.pad(x,
                  [0,
                   (x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
                   (x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
                   0])
        return x
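
# Hedged usage sketch for AutoPad: with the defaults (base=2, interpolations=3) the
# spatial dimensions are padded up to the next multiple of 2 ** 3 = 8. The 30x30
# input is an assumed example size.
def _example_autopad():
    pad = AutoPad()
    out = pad(torch.randn(1, 1, 30, 30))
    return out.shape  # torch.Size([1, 1, 32, 32])
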
class LightningBaseModule(pl.LightningModule, ABC):
@classmethod
def name(cls):
raise NotImplementedError('Give your model a name!')
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, params):
super(LightningBaseModule, self).__init__()
self.hparams = params
# Data loading
# =============================================================================
# Dataset
self.dataset = TrajData('data')
def size(self):
return self.shape
def _move_to_model_device(self, x):
return x.cuda() if next(self.parameters()).is_cuda else x.cpu()
    def save_to_disk(self, model_path):
        # `exist_ok` is an argument of mkdir(), not of the Path constructor.
        Path(model_path).mkdir(parents=True, exist_ok=True)
        if not (model_path / 'model_class.obj').exists():
            with (model_path / 'model_class.obj').open('wb') as f:
                torch.save(self.__class__, f)
        return True
@pl.data_loader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
def configure_optimizers(self):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def validation_step(self, *args, **kwargs):
raise NotImplementedError
def validation_end(self, outputs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_end(self, outputs):
from sklearn.metrics import roc_auc_score
y_scores, y_true = [], []
for output in outputs:
y_scores.append(output['y_pred'])
y_true.append(output['y_true'])
y_true = torch.cat(y_true, dim=0)
        # FIXME: What did this do? Do I still need it?
# y_true = (y_true != V.HOMOTOPIC).long()
y_scores = torch.cat(y_scores, dim=0)
roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
print(f'AUC Score: {roc_auc_scores}')
return {'roc_auc_scores': roc_auc_scores}
def init_weights(self):
def _weight_init(m):
if hasattr(m, 'weight'):
if isinstance(m.weight, torch.Tensor):
torch.nn.init.xavier_uniform_(m.weight)
if hasattr(m, 'bias'):
if isinstance(m.bias, torch.Tensor):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
#
# Sub - Modules
###################
class ConvModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
dropout: Union[int, float] = 0,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
        # Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()
        # Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
# Modules
self.dropout = nn.Dropout2d(dropout) if dropout else False
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
def forward(self, x):
x = self.norm(x) if self.norm else x
tensor = self.conv(x)
tensor = self.dropout(tensor) if self.dropout else tensor
tensor = self.pooling(tensor) if self.pooling else tensor
tensor = self.activation(tensor)
return tensor
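
# Hedged usage sketch for ConvModule, assuming a single-channel 32x32 input; the
# filter count and kernel size are illustrative, not the project's defaults.
def _example_conv_module():
    conv = ConvModule((1, 32, 32), conv_filters=16, conv_kernel=3, conv_padding=1,
                      pooling_size=2)
    return conv.shape  # torch.Size([16, 16, 16]), inferred via a dummy forward pass
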
class DeConvModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=False,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
use_bias=True, normalize=False):
super(DeConvModule, self).__init__()
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.padding = conv_padding
self.stride = conv_stride
self.in_shape = in_shape
self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else False
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False
self.dropout = nn.Dropout2d(dropout) if dropout else False
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)
self.activation = activation() if activation else None
def forward(self, x):
x = self.norm(x) if self.norm else x
x = self.dropout(x) if self.dropout else x
x = self.autopad(x) if self.autopad else x
x = self.interpolation(x) if self.interpolation else x
tensor = self.de_conv(x)
tensor = self.activation(tensor) if self.activation else tensor
return tensor
def size(self):
return self.shape
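
# Hedged usage sketch for DeConvModule: interpolation_scale=2 doubles the spatial
# size before the transposed convolution. All shapes are illustrative assumptions.
def _example_deconv_module():
    deconv = DeConvModule((16, 8, 8), conv_filters=8, conv_kernel=3, conv_padding=1,
                          interpolation_scale=2)
    return deconv.shape  # torch.Size([8, 16, 16])
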
#
# Full Model Parts
###################
class Generator(nn.Module):
@property
def shape(self):
x = torch.randn(self.lat_dim).unsqueeze(0)
output = self(x)
return output.shape[1:]
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, activation=nn.ReLU):
super(Generator, self).__init__()
        assert filters, '"filters" has to be a list of three ints'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
self.out_activation = None
self.lat_dim = lat_dim
self.dropout = dropout
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
# re_shape = (self.lat_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
conv_kernel=3,
conv_padding=1,
# normalize=use_norm,
activation=self.out_activation
)
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.flat(tensor)
tensor = self.deconv1(tensor)
tensor = self.deconv2(tensor)
tensor = self.deconv3(tensor)
tensor = self.deconv4(tensor)
return tensor
def size(self):
return self.shape
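
# Hedged end-to-end sketch for Generator: a 16-dimensional latent vector is projected
# to re_shape, then upsampled three times (x2 each) to a single-channel 32x32 map.
# Latent size, re_shape and the filter list are illustrative assumptions.
def _example_generator():
    gen = Generator(out_channels=1, re_shape=(16, 4, 4), lat_dim=16,
                    filters=[32, 16, 8])
    out = gen(torch.randn(2, 16))
    return out.shape  # torch.Size([2, 1, 32, 32])
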
class UnitGenerator(Generator):
def __init__(self, *args, **kwargs):
kwargs.update(use_norm=True)
super(UnitGenerator, self).__init__(*args, **kwargs)
self.norm_f = nn.BatchNorm1d(self.l1.out_features, eps=1e-04, affine=False)
self.norm1 = nn.BatchNorm2d(self.deconv1.conv_filters, eps=1e-04, affine=False)
self.norm2 = nn.BatchNorm2d(self.deconv2.conv_filters, eps=1e-04, affine=False)
self.norm3 = nn.BatchNorm2d(self.deconv3.conv_filters, eps=1e-04, affine=False)
def forward(self, z_c1_c2_c3):
z, c1, c2, c3 = z_c1_c2_c3
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
        tensor = self.norm_f(tensor)
tensor = self.flat(tensor)
tensor = self.deconv1(tensor) + c3
tensor = self.inner_activation(tensor)
tensor = self.norm1(tensor)
tensor = self.deconv2(tensor) + c2
tensor = self.inner_activation(tensor)
tensor = self.norm2(tensor)
tensor = self.deconv3(tensor) + c1
tensor = self.inner_activation(tensor)
tensor = self.norm3(tensor)
tensor = self.deconv4(tensor)
return tensor
class BaseEncoder(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None):
super(BaseEncoder, self).__init__()
        assert filters, '"filters" has to be a list of three ints'
        # Optional padding for odd image sizes
        # Obsolete: already handled by the AutoPad module on incoming tensors.
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
# Parameters
self.lat_dim = lat_dim
self.in_shape = in_shape
self.use_bias = use_bias
self.latent_activation = latent_activation() if latent_activation else None
# Modules
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.flat = Flatten()
def forward(self, x):
tensor = self.conv1(x)
tensor = self.conv2(tensor)
tensor = self.conv3(tensor)
tensor = self.flat(tensor)
return tensor
class UnitEncoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
kwargs.update(use_norm=True)
super(UnitEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def forward(self, x):
c1 = self.conv1(x)
c2 = self.conv2(c1)
c3 = self.conv3(c2)
tensor = self.flat(c3)
l1 = self.l1(tensor)
return c1, c2, c3, l1
class VariationalEncoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(VariationalEncoder, self).__init__(*args, **kwargs)
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def forward(self, x):
tensor = super(VariationalEncoder, self).forward(x)
mu = self.mu(tensor)
logvar = self.logvar(tensor)
z = self.reparameterize(mu, logvar)
return mu, logvar, z
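
# Hedged usage sketch for VariationalEncoder: the reparameterisation trick
# (z = mu + eps * sigma with eps ~ N(0, I)) keeps sampling differentiable with
# respect to mu and logvar. Input size, latent size and filters are assumptions.
def _example_variational_encoder():
    enc = VariationalEncoder(in_shape=(1, 32, 32), lat_dim=16, filters=[16, 32, 64])
    mu, logvar, z = enc(torch.randn(2, 1, 32, 32))
    return mu.shape, logvar.shape, z.shape  # each torch.Size([2, 16])
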
class Encoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor
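
# Hedged usage sketch for Encoder, the deterministic counterpart of
# VariationalEncoder: same convolutional backbone, single linear head to the latent
# code. Shapes and filter counts are illustrative assumptions.
def _example_encoder():
    enc = Encoder(in_shape=(1, 32, 32), lat_dim=16, filters=[16, 32, 64])
    out = enc(torch.randn(2, 1, 32, 32))
    return out.shape  # torch.Size([2, 16])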