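"""Convolutional variational autoencoder (VAE) for route/trajectory generation.

Contains the PyTorch Lightning model (CNNRouteGeneratorModel), a convolutional
Encoder that maps the input maps to a Gaussian latent distribution (mu, logvar),
and a Decoder that reconstructs from latent samples.

Configuration is read from ``self.hparams`` via the groups accessed below:
``train_param`` (lr), ``data_param`` (map_root, mode, use_preprocessed,
dataset_length) and ``model_param`` (lat_dim, filters, use_norm, use_bias,
use_res_net), plus an ``activation`` module class.
"""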
from functools import reduce
from operator import mul
from random import choice

import torch
from torch import nn
from torch.optim import Adam

from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt

import variables as V
from generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):

    # Enabled globally at class-definition time to surface NaN/Inf values in
    # autograd; helpful while debugging, but adds runtime overhead.
    torch.autograd.set_detect_anomaly(True)

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, _ = batch_xy
        reconstruction, z, mu, logvar = self(batch_x)

        # Reconstruction term (summed BCE, see self.criterion in __init__).
        recon_loss = self.criterion(reconstruction, batch_x)

        # KL divergence between the approximate posterior N(mu, sigma^2) and the
        # standard normal prior: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        loss = recon_loss + kldivergence
        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy

        mu, logvar = self.encoder(batch_x)
        z = self.reparameterize(mu, logvar)

        # For validation/testing, decode the posterior mean for a deterministic
        # reconstruction; the sampled z is still returned for inspection.
        reconstruction = self.decoder(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)

        labels = torch.full((batch_x.shape[0], 1), V.ANY)
        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
        plt.close('all')

        # Visualize a randomly chosen batch of outputs and its latent space.
        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root,
                                mode=self.hparams.data_param.mode,
                                preprocessed=self.hparams.data_param.use_preprocessed,
                                length=self.hparams.data_param.dataset_length)

        # Summed BCE matches the sigmoid output of the decoder.
        self.criterion = nn.BCELoss(reduction='sum')

        # Additional Attributes
        ###################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        # In 'generator' mode only a single map channel is reconstructed,
        # otherwise all input channels are.
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]

        # NN Nodes
        ###################################################
        self.encoder = Encoder(self.in_shape, self.hparams)
        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)

    def forward(self, batch_x):
        # Encode
        mu, logvar = self.encoder(batch_x)

        # Bottleneck
        z = self.reparameterize(mu, logvar)

        # Decode
        reconstruction = self.decoder(z)
        return reconstruction, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        # and sigma = exp(0.5 * logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        z = mu + std * eps
        return z


class Encoder(nn.Module):

    def __init__(self, in_shape, hparams):
        super(Encoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.in_shape = in_shape
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Encoder
        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[1],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[2],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.last_conv_shape = self.conv_3.shape
        self.flat = Flatten(in_shape=self.last_conv_shape)
        self.lin = nn.Linear(self.flat.shape, self.feature_dim)

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

    def forward(self, batch_x):
        tensor = self.conv_0(batch_x)
        tensor = self.conv_1(tensor)
        tensor = self.conv_2(tensor)
        tensor = self.conv_3(tensor)

        tensor = self.flat(tensor)
        tensor = self.lin(tensor)
        tensor = self.activation(tensor)

        #
        # Variational
        # Parameters for Sampling
        mu = self.mu(tensor)
        logvar = self.logvar(tensor)
        return mu, logvar


class Decoder(nn.Module):

    def __init__(self, out_channels, last_conv_shape, hparams):
        super(Decoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = out_channels

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Alternative Generator
        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))

        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)

        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                     conv_padding=0, conv_kernel=7, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
                                     conv_padding=1, conv_kernel=5, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
                                     conv_padding=0, conv_kernel=3, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
                                       conv_padding=0, conv_kernel=3, conv_stride=1,
                                       use_norm=self.hparams.model_param.use_norm)

    def forward(self, z):
        tensor = self.lin(z)
        tensor = self.activation(tensor)
        tensor = self.reshape(tensor)

        tensor = self.deconv_1(tensor)
        tensor = self.deconv_2(tensor)
        tensor = self.deconv_3(tensor)
        reconstruction = self.deconv_out(tensor)
        return reconstruction
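

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original training pipeline):
    # wires Encoder -> reparameterize -> Decoder on a dummy batch. The hparams
    # Namespace below only mimics the attribute groups accessed in this file;
    # the concrete values (lat_dim, filters, input shape) are illustrative
    # assumptions, as is the availability of ml_lib's ConvModule/DeConvModule
    # with the signatures used above.
    from argparse import Namespace

    model_param = Namespace(lat_dim=16, filters=[16, 32, 64],
                            use_norm=False, use_bias=True, use_res_net=False)
    hparams = Namespace(activation=nn.ReLU, model_param=model_param)

    in_shape = (1, 64, 64)  # (channels, height, width), assumed for the sketch
    encoder = Encoder(in_shape, hparams)
    decoder = Decoder(out_channels=1,
                      last_conv_shape=encoder.last_conv_shape,
                      hparams=hparams)

    dummy_batch = torch.randn(4, *in_shape)
    mu, logvar = encoder(dummy_batch)
    z = CNNRouteGeneratorModel.reparameterize(mu, logvar)
    reconstruction = decoder(z)
    print(reconstruction.shape)  # expected: (4, 1, H, W)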