from functools import reduce
from operator import mul
from random import choice

import torch
from torch import nn
from torch.optim import Adam

import matplotlib.pyplot as plt

from datasets.trajectory_dataset import TrajData
from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten

import variables as V
from generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):

    # Debug flag: fail at the operation that produces a NaN/Inf gradient
    # instead of at the end of backward().
    torch.autograd.set_detect_anomaly(True)

    name = 'CNNRouteGenerator'

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root,
                                mode=self.hparams.data_param.mode,
                                preprocessed=self.hparams.data_param.use_preprocessed,
                                length=self.hparams.data_param.dataset_length)

        self.criterion = nn.BCELoss(reduction='sum')

        # Additional Attributes
        ###################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]

        # NN Nodes
        ###################################################
        self.encoder = Encoder(self.in_shape, self.hparams)
        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def forward(self, batch_x):
        # Encode
        mu, logvar = self.encoder(batch_x)

        # Bottleneck
        z = self.reparameterize(mu, logvar)

        # Decode
        reconstruction = self.decoder(z)
        return reconstruction, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        # Sample z = mu + sigma * eps with eps ~ N(0, I). The encoder predicts
        # log(sigma^2), so sigma = exp(0.5 * logvar), not 0.5 * exp(logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        z = mu + std * eps
        return z

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, _ = batch_xy
        reconstruction, z, mu, logvar = self(batch_x)
        recon_loss = self.criterion(reconstruction, batch_x)
        # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kldivergence
        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy
        mu, logvar = self.encoder(batch_x)
        z = self.reparameterize(mu, logvar)
        # Decode the posterior mean for a deterministic reconstruction at
        # validation/test time; the sampled z is still returned for plotting.
        reconstruction = self.decoder(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction,
                           z=z, mu=mu, logvar=logvar)

        labels = torch.full((batch_x.shape[0], 1), V.ANY)
        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
        plt.close('all')

        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)
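
# The two formulas the model relies on, shown in isolation so they are easy to
# verify: the reparameterization trick used in `reparameterize` and the
# closed-form KL term used in `training_step`. This is an illustrative sketch,
# not part of the model; names and tensor shapes are arbitrary.
def _vae_objective_sketch():
    mu = torch.zeros(4, 8)                # posterior means, (batch, lat_dim)
    logvar = torch.zeros(4, 8)            # posterior log-variances
    std = torch.exp(0.5 * logvar)         # sigma = exp(0.5 * log(sigma^2))
    z = mu + std * torch.randn_like(std)  # z = mu + sigma * eps, eps ~ N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return z, kld                         # kld == 0 here, since q == N(0, I)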

class Encoder(nn.Module):
    def __init__(self, in_shape, hparams):
        super(Encoder, self).__init__()

        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.in_shape = in_shape
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Encoder
        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[1],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[2],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.last_conv_shape = self.conv_3.shape
        self.flat = Flatten(in_shape=self.last_conv_shape)
        self.lin = nn.Linear(self.flat.shape, self.feature_dim)

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

    def forward(self, batch_x):
        tensor = self.conv_0(batch_x)
        tensor = self.conv_1(tensor)
        tensor = self.conv_2(tensor)
        tensor = self.conv_3(tensor)

        tensor = self.flat(tensor)
        tensor = self.lin(tensor)
        tensor = self.activation(tensor)

        #
        # Variational
        # Parameters for sampling
        mu = self.mu(tensor)
        logvar = self.logvar(tensor)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, out_channels, last_conv_shape, hparams):
        super(Decoder, self).__init__()

        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = out_channels

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Alternative Generator
        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))
        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)

        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                     conv_padding=0, conv_kernel=7, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
                                     conv_padding=1, conv_kernel=5, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
                                     conv_padding=0, conv_kernel=3, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
                                       conv_padding=0, conv_kernel=3, conv_stride=1,
                                       use_norm=self.hparams.model_param.use_norm)

    def forward(self, z):
        tensor = self.lin(z)
        tensor = self.activation(tensor)
        tensor = self.reshape(tensor)

        tensor = self.deconv_1(tensor)
        tensor = self.deconv_2(tensor)
        tensor = self.deconv_3(tensor)

        reconstruction = self.deconv_out(tensor)
        return reconstruction
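

# A minimal smoke test for the encoder/decoder pair in isolation, assuming
# `ConvModule`/`DeConvModule` wrap plain (transposed) 2d convolutions with the
# keyword arguments used above. The hyper-parameter namespace below is a
# hypothetical stand-in: the attribute names mirror the accesses in this file,
# the values are arbitrary.
if __name__ == '__main__':
    from types import SimpleNamespace

    hparams = SimpleNamespace(
        activation=nn.ReLU,
        model_param=SimpleNamespace(use_res_net=False, lat_dim=16,
                                    filters=[16, 32, 64],
                                    use_norm=False, use_bias=True),
    )

    in_shape = (1, 64, 64)  # single-channel map, arbitrary resolution
    encoder = Encoder(in_shape, hparams)
    decoder = Decoder(in_shape[0], encoder.last_conv_shape, hparams)

    batch_x = torch.rand(2, *in_shape)
    mu, logvar = encoder(batch_x)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    # With the kernel/padding choices above, the deconv stack should invert
    # the conv stack and restore the input resolution: (2, 1, 64, 64).
    print(decoder(z).shape)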