diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 6e2ad6c..169aa60 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # Dimensional Resizing TODO: Does This make sense? Sanity Check it! # kld_loss /= reduce(mul, self.in_shape) - kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 + # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 loss = (kld_loss + element_wise_loss) / 2 return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss)) @@ -54,6 +54,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) + plt.clf() return dict(epoch=self.current_epoch) @@ -299,12 +300,14 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): # self.logger.log_metrics(score_dict) self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step) plt.clf() + maps, trajectories, labels, val_restul_dict = self.generate_random() from lib.visualization.generator_eval import GeneratorVisualizer g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) + plt.clf() return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)