469 lines
16 KiB
Python
469 lines
16 KiB
Python
from abc import ABC
from functools import reduce
from operator import mul
from pathlib import Path
from typing import List, Union

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
|
|
|
|
|
|
# Utility - Modules
|
|
###################
|
|
from torch.utils.data import DataLoader
|
|
|
|
from dataset.dataset import TrajData
|
|
|
|
|
|
class Flatten(nn.Module):
    """Reshape every sample of a batch to a fixed per-sample shape.

    The batch dimension (dim 0) is kept, all remaining dimensions are
    replaced by ``to``.

    Args:
        to: Target shape of a single sample. The default ``(-1, )`` flattens
            each sample into a vector.
    """

    def __init__(self, to=(-1, )):
        super(Flatten, self).__init__()
        self.to = to

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), *self.to)``."""
        # ``reshape`` yields a view whenever ``view`` would have worked and only
        # copies for non-contiguous input, which ``view`` rejects with an error.
        return x.reshape(x.size(0), *self.to)
|
|
|
|
|
|
class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    The constructor arguments are stored and handed unchanged to
    ``interpolate`` on every forward call.

    Args:
        size: Forwarded as ``size``.
        scale_factor: Forwarded as ``scale_factor``.
        mode: Forwarded as ``mode`` (default ``'nearest'``).
        align_corners: Forwarded as ``align_corners``.
    """

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.size, self.scale_factor = size, scale_factor
        self.mode, self.align_corners = mode, align_corners

    def forward(self, x):
        """Return the resampled version of ``x``."""
        resample_kwargs = dict(
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return self.interp(x, **resample_kwargs)
|
|
|
|
|
|
class AutoPad(nn.Module):
    """Zero-pad the last two dimensions up to the next multiple of ``base ** interpolations``.

    Args:
        interpolations: Exponent of the size factor.
        base: Base of the size factor.
    """

    def __init__(self, interpolations=3, base=2):
        super().__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        """Return ``x`` padded so that both trailing dimensions are divisible by ``self.fct``."""

        def missing(length):
            # Number of entries lacking up to the next multiple of ``self.fct``.
            remainder = length % self.fct
            return self.fct - remainder if remainder != 0 else 0

        # ``F.pad`` takes (start, end) pairs beginning with the last dimension:
        # the last dimension grows at its end, the second-to-last at its start.
        padding = [0, missing(x.shape[-1]), missing(x.shape[-2]), 0]
        return F.pad(x, padding)
|
|
|
|
|
|
class LightningBaseModule(pl.LightningModule, ABC):
    """Shared base class for the Lightning models in this file.

    It stores the hyperparameters, creates the ``TrajData`` dataset, builds the
    train / validation / test data loaders and offers a few helpers (shape
    probing, device handling, persisting the model class, weight
    initialisation).  ``configure_optimizers``, ``forward`` and the training /
    validation / test step hooks raise ``NotImplementedError`` and have to be
    provided by subclasses.
    """

    @classmethod
    def name(cls):
        """Return the model name; the base implementation always raises."""
        raise NotImplementedError('Give your model a name!')

    @property
    def shape(self):
        """Per-sample output shape (batch dimension stripped), or ``-1`` on failure.

        A random tensor of shape ``self.in_shape`` (expected to be set by the
        subclass) is pushed through the model.  Any exception raised while
        doing so is printed and ``-1`` is returned instead.
        """
        # Remember the mode of every submodule so the probe leaves no trace.
        training_flags = [(module, module.training) for module in self.modules()]
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            # Eval mode + no_grad: random probe data must neither update
            # BatchNorm running statistics nor build an autograd graph.
            self.eval()
            with torch.no_grad():
                output = self(x)
            return output.shape[1:]
        except Exception as e:
            # Deliberate best effort: report the problem, signal it with -1.
            print(e)
            return -1
        finally:
            for module, flag in training_flags:
                module.training = flag

    def __init__(self, params):
        """Store the hyperparameters and load the dataset.

        Args:
            params: Hyperparameter container.  The data loaders read
                ``params.data_param.batchsize`` and ``params.data_param.worker``.
        """
        super(LightningBaseModule, self).__init__()
        self.hparams = params

        # Data loading
        # =============================================================================
        # Dataset
        self.dataset = TrajData('data')

    def size(self):
        """Alias for the ``shape`` property."""
        return self.shape

    def _move_to_model_device(self, x):
        """Return ``x`` on the device of the model's first parameter."""
        # ``.to(device)`` also honours the device index, unlike a bare ``.cuda()``.
        return x.to(next(self.parameters()).device)

    def save_to_disk(self, model_path):
        """Write the model *class* to ``<model_path>/model_class.obj``.

        The directory is created if necessary; an existing ``model_class.obj``
        is left untouched.

        Args:
            model_path: Target directory (``str`` or ``Path``).

        Returns:
            Always ``True``.
        """
        # Normalise first: ``Path`` takes no ``exist_ok`` argument and the ``/``
        # operator below needs a ``Path`` even when a plain string is passed.
        model_path = Path(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        class_file = model_path / 'model_class.obj'
        if not class_file.exists():
            with class_file.open('wb') as f:
                torch.save(self.__class__, f)
        return True

    def _build_loader(self, dataset):
        """Wrap ``dataset`` in a shuffling ``DataLoader`` configured from ``hparams.data_param``."""
        data_param = self.hparams.data_param
        return DataLoader(dataset=dataset, shuffle=True,
                          batch_size=data_param.batchsize,
                          num_workers=data_param.worker)

    @pl.data_loader
    def train_dataloader(self):
        """Data loader over ``self.dataset.train_dataset``."""
        return self._build_loader(self.dataset.train_dataset)

    @pl.data_loader
    def test_dataloader(self):
        """Data loader over ``self.dataset.test_dataset``."""
        return self._build_loader(self.dataset.test_dataset)

    @pl.data_loader
    def val_dataloader(self):
        """Data loader over ``self.dataset.val_dataset``."""
        return self._build_loader(self.dataset.val_dataset)

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def validation_step(self, *args, **kwargs):
        raise NotImplementedError

    def validation_end(self, outputs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_end(self, outputs):
        """Compute the ROC-AUC over all collected test step outputs.

        Args:
            outputs: Iterable of dicts holding the tensors ``'y_pred'`` and
                ``'y_true'``.

        Returns:
            ``{'roc_auc_scores': <score>}``; the score is printed as well.
        """
        from sklearn.metrics import roc_auc_score

        y_scores = torch.cat([output['y_pred'] for output in outputs], dim=0)
        y_true = torch.cat([output['y_true'] for output in outputs], dim=0)
        # FIXME: What did this do do i need it?
        # y_true = (y_true != V.HOMOTOPIC).long()

        roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
        print(f'AUC Score: {roc_auc_scores}')
        return {'roc_auc_scores': roc_auc_scores}

    def init_weights(self):
        """Xavier-uniform initialise all tensor weights and set tensor biases to 0.01."""

        def _weight_init(m):
            if hasattr(m, 'weight'):
                if isinstance(m.weight, torch.Tensor):
                    torch.nn.init.xavier_uniform_(m.weight)
            if hasattr(m, 'bias'):
                if isinstance(m.bias, torch.Tensor):
                    m.bias.data.fill_(0.01)

        self.apply(_weight_init)
|
|
|
|
|
|
#
|
|
# Sub - Modules
|
|
###################
|
|
class ConvModule(nn.Module):
    """Convolution block: [BatchNorm2d] -> Conv2d -> [Dropout2d] -> [MaxPool2d] -> activation.

    Bracketed stages are optional and skipped when disabled.

    Args:
        in_shape: Per-sample input shape, indexed as ``(channels, height, width)``.
        activation: Activation *class*; it is instantiated without arguments.
        pooling_size: Kernel size of the max pooling; falsy disables pooling.
        use_bias: Whether the convolution has a bias term.
        use_norm: Whether the input is batch-normalised before the convolution.
        dropout: Dropout probability; falsy disables dropout.
        conv_filters: Number of output channels of the convolution.
        conv_kernel: Kernel size of the convolution.
        conv_stride: Stride of the convolution.
        conv_padding: Padding of the convolution.
    """

    @property
    def shape(self):
        """Per-sample output shape (batch dimension stripped) for an input of ``in_shape``."""
        # Remember the mode of every submodule so the probe leaves no trace.
        training_flags = [(module, module.training) for module in self.modules()]
        x = torch.randn(self.in_shape).unsqueeze(0)
        # Eval mode + no_grad: random probe data must neither update BatchNorm
        # running statistics nor build an autograd graph.
        self.eval()
        try:
            with torch.no_grad():
                output = self(x)
        finally:
            for module, flag in training_flags:
                module.training = flag
        return output.shape[1:]

    def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
                 dropout: Union[int, float] = 0,
                 conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
        super(ConvModule, self).__init__()

        # Module parameters
        self.in_shape = in_shape
        # All three entries are indexed so that a malformed ``in_shape`` fails here.
        in_channels, _height, _width = in_shape[0], in_shape[1], in_shape[2]
        self.activation = activation()

        # Convolution parameters
        self.padding = conv_padding
        self.stride = conv_stride

        # Modules; a disabled stage is stored as ``False`` and skipped in ``forward``.
        self.dropout = nn.Dropout2d(dropout) if dropout else False
        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
        self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
                              padding=self.padding, stride=self.stride
                              )

    def forward(self, x):
        """Apply the block to ``x`` and return the activated feature map."""
        x = self.norm(x) if self.norm else x

        tensor = self.conv(x)
        tensor = self.dropout(tensor) if self.dropout else tensor
        tensor = self.pooling(tensor) if self.pooling else tensor
        tensor = self.activation(tensor)
        return tensor
|
|
|
|
|
|
class DeConvModule(nn.Module):
    """Transposed-convolution block.

    Pipeline: [BatchNorm2d] -> [Dropout2d] -> [AutoPad] -> [Interpolate] ->
    ConvTranspose2d -> [activation].  Bracketed stages are optional and skipped
    when disabled.

    Args:
        in_shape: Per-sample input shape, indexed as ``(channels, height, width)``.
        conv_filters: Number of output channels of the transposed convolution.
        conv_kernel: Kernel size of the transposed convolution.
        conv_stride: Stride of the transposed convolution.
        conv_padding: Padding of the transposed convolution.
        dropout: Dropout probability; falsy disables dropout.
        autopad: Whether the input is padded by a default ``AutoPad``.
        activation: Activation *class* (instantiated without arguments) or
            ``None`` for no activation.
        interpolation_scale: Scale factor handed to ``Interpolate``; falsy
            disables the interpolation.
        use_bias: Whether the transposed convolution has a bias term.
        normalize: Whether the input is batch-normalised first.
    """

    @property
    def shape(self):
        """Per-sample output shape (batch dimension stripped) for an input of ``in_shape``."""
        # Remember the mode of every submodule so the probe leaves no trace.
        training_flags = [(module, module.training) for module in self.modules()]
        x = torch.randn(self.in_shape).unsqueeze(0)
        # Eval mode + no_grad: random probe data must neither update BatchNorm
        # running statistics nor build an autograd graph.
        self.eval()
        try:
            with torch.no_grad():
                output = self(x)
        finally:
            for module, flag in training_flags:
                module.training = flag
        return output.shape[1:]

    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
                 dropout: Union[int, float] = 0, autopad=False,
                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
                 use_bias=True, normalize=False):
        super(DeConvModule, self).__init__()
        # All three entries are indexed so that a malformed ``in_shape`` fails here.
        in_channels, _height, _width = in_shape[0], in_shape[1], in_shape[2]
        self.padding = conv_padding
        self.stride = conv_stride
        self.in_shape = in_shape
        self.conv_filters = conv_filters

        # A disabled stage is stored as ``False`` / ``None`` and skipped in ``forward``.
        self.autopad = AutoPad() if autopad else False
        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else False
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else False
        self.dropout = nn.Dropout2d(dropout) if dropout else False
        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
                                          padding=self.padding, stride=self.stride)

        self.activation = activation() if activation else None

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        x = self.norm(x) if self.norm else x
        x = self.dropout(x) if self.dropout else x
        x = self.autopad(x) if self.autopad else x
        x = self.interpolation(x) if self.interpolation else x

        tensor = self.de_conv(x)
        tensor = self.activation(tensor) if self.activation else tensor
        return tensor

    def size(self):
        """Alias for the ``shape`` property."""
        return self.shape
|
|
|
|
|
|
#
|
|
# Full Model Parts
|
|
###################
|
|
class Generator(nn.Module):
    """Decoder that maps a latent vector to a feature map.

    Pipeline: Linear -> activation -> reshape to ``re_shape`` -> three
    ``DeConvModule`` stages (each with ``interpolation_scale=2``) -> a final
    ``DeConvModule`` producing ``out_channels`` channels without activation,
    normalisation, dropout or interpolation.

    Args:
        out_channels: Number of channels of the generated output.
        re_shape: Per-sample shape the linear output is reshaped to before the
            first de-convolution; its product is the linear layer's width.
        lat_dim: Size of the latent input vector.
        use_norm: Whether the first three de-conv stages normalise their input.
        use_bias: Whether the linear layer has a bias term.
        dropout: Dropout probability of the first three de-conv stages.
        filters: Channel counts of the first three de-conv stages; entries 0-2
            are used.  Must be non-empty.
        activation: Activation *class* used after the linear layer and inside
            the first three de-conv stages.
    """

    @property
    def shape(self):
        """Per-sample output shape (batch dimension stripped) for a latent vector of size ``lat_dim``."""
        # Remember the mode of every submodule so the probe leaves no trace.
        training_flags = [(module, module.training) for module in self.modules()]
        x = torch.randn(self.lat_dim).unsqueeze(0)
        # Eval mode + no_grad: random probe data must neither update BatchNorm
        # running statistics nor build an autograd graph.
        self.eval()
        try:
            with torch.no_grad():
                output = self(x)
        finally:
            for module, flag in training_flags:
                module.training = flag
        return output.shape[1:]

    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
                 filters: List[int] = None, activation=nn.ReLU):
        super(Generator, self).__init__()
        assert filters, '"Filters" has to be a list of int len 3'
        self.filters = filters
        self.activation = activation
        self.inner_activation = activation()
        self.out_activation = None
        self.lat_dim = lat_dim
        self.dropout = dropout
        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)

        self.flat = Flatten(to=re_shape)

        # Each following stage takes the probed output shape of its predecessor
        # as its own input shape.
        self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
                                    conv_kernel=5,
                                    conv_padding=2,
                                    conv_stride=1,
                                    normalize=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
                                    conv_kernel=3,
                                    conv_padding=1,
                                    conv_stride=1,
                                    normalize=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
                                    conv_kernel=3,
                                    conv_padding=1,
                                    conv_stride=1,
                                    normalize=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        # Output stage: ``normalize`` is intentionally not passed here.
        self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
                                    conv_kernel=3,
                                    conv_padding=1,
                                    activation=self.out_activation
                                    )

    def forward(self, z):
        """Generate an output feature map from the latent batch ``z``."""
        tensor = self.l1(z)
        tensor = self.inner_activation(tensor)
        tensor = self.flat(tensor)
        tensor = self.deconv1(tensor)
        tensor = self.deconv2(tensor)
        tensor = self.deconv3(tensor)
        tensor = self.deconv4(tensor)
        return tensor

    def size(self):
        """Alias for the ``shape`` property."""
        return self.shape
|
|
|
|
|
|
class UnitGenerator(Generator):
    """``Generator`` variant with skip connections and extra batch normalisation.

    ``use_norm`` is always forced to ``True``.  ``forward`` expects the latent
    vector together with three feature maps that are added to the outputs of
    the first three de-conv stages.

    NOTE(review): the inherited ``shape`` property feeds a single tensor into
    ``forward`` and therefore does not match the 4-tuple signature used here.
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(use_norm=True)
        super(UnitGenerator, self).__init__(*args, **kwargs)
        self.norm_f = nn.BatchNorm1d(self.l1.out_features, eps=1e-04, affine=False)
        self.norm1 = nn.BatchNorm2d(self.deconv1.conv_filters, eps=1e-04, affine=False)
        self.norm2 = nn.BatchNorm2d(self.deconv2.conv_filters, eps=1e-04, affine=False)
        self.norm3 = nn.BatchNorm2d(self.deconv3.conv_filters, eps=1e-04, affine=False)

    def forward(self, z_c1_c2_c3):
        """Generate an output from ``(z, c1, c2, c3)``.

        Args:
            z_c1_c2_c3: 4-tuple of the latent batch ``z`` and the feature maps
                ``c1``, ``c2``, ``c3``.  ``c3`` is added after the first
                de-conv stage, ``c2`` after the second and ``c1`` after the
                third, so each must be addable to the respective stage output.
        """
        z, c1, c2, c3 = z_c1_c2_c3
        tensor = self.l1(z)
        tensor = self.inner_activation(tensor)
        # The 1d norm over the linear features is registered as ``norm_f``;
        # there is no attribute called ``norm`` on this module.
        tensor = self.norm_f(tensor)
        tensor = self.flat(tensor)

        tensor = self.deconv1(tensor) + c3
        tensor = self.inner_activation(tensor)
        tensor = self.norm1(tensor)

        tensor = self.deconv2(tensor) + c2
        tensor = self.inner_activation(tensor)
        tensor = self.norm2(tensor)

        tensor = self.deconv3(tensor) + c1
        tensor = self.inner_activation(tensor)
        tensor = self.norm3(tensor)

        tensor = self.deconv4(tensor)
        return tensor
|
|
|
|
|
|
class BaseEncoder(nn.Module):
    """Convolutional feature extractor: three ``ConvModule`` stages followed by a flatten.

    Every stage uses stride 1 and max pooling of size 2.  Subclasses add the
    projection into the latent space.

    Args:
        in_shape: Per-sample input shape handed to the first ``ConvModule``.
        lat_dim: Size of the latent space; only stored here for subclasses.
        use_bias: Stored for subclasses (bias flag of their linear layers).
        use_norm: Whether the conv stages batch-normalise their input.
        dropout: Dropout probability of the conv stages.
        latent_activation: Activation *class* (instantiated without arguments)
            or ``None``; only stored here for subclasses.
        activation: Activation *class* of the conv stages.
        filters: Channel counts of the three conv stages; entries 0-2 are
            used.  Must be non-empty.
    """

    @property
    def shape(self):
        """Per-sample output shape (batch dimension stripped) for an input of ``in_shape``."""
        # Remember the mode of every submodule so the probe leaves no trace.
        training_flags = [(module, module.training) for module in self.modules()]
        x = torch.randn(self.in_shape).unsqueeze(0)
        # Eval mode + no_grad: random probe data must neither update BatchNorm
        # running statistics nor build an autograd graph.
        self.eval()
        try:
            with torch.no_grad():
                output = self(x)
        finally:
            for module, flag in training_flags:
                module.training = flag
        return output.shape[1:]

    def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                 latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
                 filters: List[int] = None):
        super(BaseEncoder, self).__init__()
        assert filters, '"Filters" has to be a list of int len 3'

        # Odd image sizes are not handled here; per the original author's note
        # incoming tensors are already padded by the autopadding module.

        # Parameters
        self.lat_dim = lat_dim
        self.in_shape = in_shape
        self.use_bias = use_bias
        self.latent_activation = latent_activation() if latent_activation else None

        # Modules; each following stage takes the probed output shape of its
        # predecessor as its own input shape.
        self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
                                conv_kernel=3,
                                conv_padding=1,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
                                conv_kernel=3,
                                conv_padding=1,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
                                conv_kernel=5,
                                conv_padding=2,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.flat = Flatten()

    def forward(self, x):
        """Return the flattened features of ``x`` after the three conv stages."""
        tensor = self.conv1(x)
        tensor = self.conv2(tensor)
        tensor = self.conv3(tensor)
        tensor = self.flat(tensor)
        return tensor
|
|
|
|
|
|
class UnitEncoder(BaseEncoder):
    """Encoder that returns the intermediate conv feature maps next to the latent code.

    ``use_norm`` is always forced to ``True``.
    """

    def __init__(self, *args, **kwargs):
        kwargs['use_norm'] = True
        super().__init__(*args, **kwargs)
        # Width of the flattened output of the last conv stage.
        n_features = reduce(mul, self.conv3.shape)
        self.l1 = nn.Linear(n_features, self.lat_dim, bias=self.use_bias)

    def forward(self, x):
        """Return ``(c1, c2, c3, latent)``: the three conv stage outputs and the linear projection."""
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        latent = self.l1(self.flat(c3))
        return c1, c2, c3, latent
|
|
|
|
|
|
class VariationalEncoder(BaseEncoder):
    """Encoder with two linear heads (``mu`` and ``logvar``) and a sampled latent code."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The heads are registered in this order, each sized by its own probe
        # of the last conv stage's output shape.
        for head_name in ('logvar', 'mu'):
            head = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
            setattr(self, head_name, head)

    @staticmethod
    def reparameterize(mu, logvar):
        """Return ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and standard normal ``eps``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Return ``(mu, logvar, z)`` where ``z`` is sampled via ``reparameterize``."""
        features = super().forward(x)
        mu = self.mu(features)
        logvar = self.logvar(features)
        return mu, logvar, self.reparameterize(mu, logvar)
|
|
|
|
|
|
class Encoder(BaseEncoder):
    """Plain encoder: conv features -> linear projection -> optional latent activation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Width of the flattened output of the last conv stage.
        n_features = reduce(mul, self.conv3.shape)
        self.l1 = nn.Linear(n_features, self.lat_dim, bias=self.use_bias)

    def forward(self, x):
        """Return the latent code of ``x``."""
        latent = self.l1(super().forward(x))
        if self.latent_activation:
            latent = self.latent_activation(latent)
        return latent
|