from functools import reduce
from operator import mul

from random import choices, choice

import torch

from torch import nn
from torch.optim import Adam

from torchvision.datasets import MNIST

from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt

import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):
    """Convolutional encoder/decoder ("route generator") with optional variational bottleneck.

    Behaviour is selected by substrings of ``hparams.data_param.mode``:

    * ``'vae'``       -- use a variational bottleneck (``mu``/``logvar`` + reparameterization)
                         and add a KL-divergence term to the training loss.
    * ``'ae'``        -- reconstruct the input itself (note: ``'vae'`` also contains ``'ae'``).
    * ``'generator'`` -- the decoder emits a single output channel instead of ``in_shape[0]``.
    * ``'hom'`` / ``'alt'`` -- select the constant label attached to validation/test outputs.
    """

    name = 'CNNRouteGenerator'

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # NOTE(review): the trajectory dataset is currently disabled via `if False:`;
        # MNIST is used instead (see below).
        if False:
            # Dataset
            self.dataset = TrajData(self.hparams.data_param.map_root, mode=self.hparams.data_param.mode,
                                    preprocessed=self.hparams.data_param.use_preprocessed,
                                    length=self.hparams.data_param.dataset_length, normalized=True)
        self.criterion = nn.MSELoss()

        self.dataset = MyMNIST()

        # Additional Attributes #
        #######################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
        ########################################################

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        # Residual blocks are only applied in encode() when `use_res_net` is set; they use
        # stride 1 / padding 2 / kernel 5 with unchanged filter counts so the following
        # convolution sees the same input shape either way.
        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        last_conv_shape = self.enc_conv_3a.shape
        self.enc_flat = Flatten(last_conv_shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        # NOTE(review): enc_norm is constructed but never applied in encode() -- confirm intent.
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

        #
        # Variational Bottleneck
        if 'vae' in self.hparams.data_param.mode:
            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Linear Bottleneck
        else:
            self.z = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)

        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)

        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)

        # Decoder mirrors the encoder's kernel sizes in reverse order (7, 5, 3, 3).
        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                          conv_padding=0, conv_kernel=7, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
                                          conv_padding=1, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
                                          conv_padding=0, conv_kernel=3, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
                                           conv_padding=0, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)

    # ------------------------------------------------------------------ #
    # Lightning hooks
    # ------------------------------------------------------------------ #
    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.train_param.lr``."""
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    @staticmethod
    def _kld_loss(mu, logvar):
        """KL divergence term of the VAE loss, summed over all elements.

        See Appendix B from the VAE paper:
        Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        https://arxiv.org/abs/1312.6114
        -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        """
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        :param batch_xy: ``(batch_x, target)`` pair; in any ``'ae'`` mode (including ``'vae'``)
                         the target is replaced by ``batch_x`` itself.
        :param batch_nb: Batch index (unused).
        :return: Dict with ``loss`` and a ``log`` dict holding ``element_wise_loss``,
                 ``loss`` and ``kld_loss`` (``0`` outside ``'vae'`` mode).
        """
        batch_x, target = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        mode = self.hparams.data_param.mode
        target = batch_x if 'ae' in mode else target
        element_wise_loss = self.criterion(generated_alternative, target)

        if 'vae' in mode:
            kld_loss = self._kld_loss(mu, logvar)
            # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
            # kld_loss /= reduce(mul, self.in_shape)
            # Scales the batch KL term by the number of batches per epoch.
            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size

            loss = kld_loss + element_wise_loss
        else:
            loss = element_wise_loss
            kld_loss = 0
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Shared validation/test step: encode, decode from ``mu`` and attach labels.

        :param batch_xy: ``(batch_x, _)`` pair; only ``batch_x`` is used.
        :param batch_nb: Batch index, passed through in the returned dict.
        :return: Dict with ``input``, ``batch_nb``, ``output``, ``z``, ``mu``, ``logvar``
                 and ``labels`` (moved to the model device).
        """
        batch_x, _ = batch_xy
        z, mu, logvar = self._encode_with_stats(batch_x)

        # Decode from mu (equal to z outside 'vae' mode) so the output does not depend
        # on the reparameterization noise.
        generated_alternative = self.generate(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative,
                           z=z, mu=mu, logvar=logvar)

        mode = self.hparams.data_param.mode
        batch_size = batch_x.shape[0]
        if 'hom' in mode:
            labels = torch.full((batch_size, 1), V.HOMOTOPIC)
        elif 'alt' in mode:
            labels = torch.full((batch_size, 1), V.ALTERNATIVE)
        elif 'ae' in mode:
            # Covers both 'ae' and 'vae' (the latter contains 'ae').
            labels = torch.full((batch_size, 1), V.ANY)
        else:
            # Derive labels from channel 2 of the input (max over both trailing axes).
            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values

        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
        """Log an input/output bundle and a latent plot for one randomly chosen step output.

        :param outputs: List of dicts returned by ``_test_val_step``.
        :param test: Unused; kept for interface compatibility.
        :return: ``dict(epoch=self.current_epoch)``.
        """
        plt.close('all')
        # random.choice raises IndexError on an empty sequence: nothing to plot, just report the epoch.
        if not outputs:
            return dict(epoch=self.current_epoch)

        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def on_epoch_start(self):
        """Per-epoch hook; currently a no-op (seeding is disabled)."""
        # self.dataset.seed(self.logger.version)
        # torch.random.manual_seed(self.logger.version)
        # np.random.seed(self.logger.version)
        pass

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    # ------------------------------------------------------------------ #
    # Model
    # ------------------------------------------------------------------ #
    def _encode_with_stats(self, batch_x):
        """Encode ``batch_x`` and always return ``(z, mu, logvar)``.

        Outside ``'vae'`` mode the encoder only yields ``z``; it is then reused for
        ``mu`` and ``logvar`` so callers get a uniform triple.
        """
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z
        return z, mu, logvar

    def forward(self, batch_x):
        """Encode and decode ``batch_x``.

        :return: ``(alt_tensor, z, mu, logvar)``; the decoder is fed the (sampled) ``z``.
        """
        #
        # Encode
        z, mu, logvar = self._encode_with_stats(batch_x)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, batch_x):
        """Run the convolutional + linear encoder.

        :return: ``(z, mu, logvar)`` in ``'vae'`` mode, otherwise just ``z``.
        """
        combined_tensor = self.enc_conv_0(batch_x)
        # The residual blocks are built in __init__ (and their parameters are optimized),
        # so honour `use_res_net` here; they preserve the tensor shape.
        if self.use_res_net:
            combined_tensor = self.enc_res_1(combined_tensor)
        combined_tensor = self.enc_conv_1a(combined_tensor)
        if self.use_res_net:
            combined_tensor = self.enc_res_2(combined_tensor)
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_3a(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        combined_tensor = self.enc_lin_1(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        combined_tensor = self.enc_lin_2(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Variational
        # Parameter and Sampling
        if 'vae' in self.hparams.data_param.mode:
            mu = self.mu(combined_tensor)
            logvar = self.logvar(combined_tensor)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar

        #
        # Linear Bottleneck
        z = self.z(combined_tensor)
        return z

    def generate(self, z):
        """Decode a latent vector into an output map via linear projection and transposed convolutions."""
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)
        alt_tensor = self.gen_deconv_2a(alt_tensor)
        alt_tensor = self.gen_deconv_3a(alt_tensor)
        alt_tensor = self.gen_deconv_out(alt_tensor)
        # Output is left unbounded (no final activation / sigmoid).
        # alt_tensor = self.activation(alt_tensor)
        # alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor