fig clf inserted and not resize on kld

This commit is contained in:
Steffen Illium
2020-03-12 07:10:21 +01:00
parent bb8151d9ba
commit 2f99341cc3

View File

@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it! # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape) # kld_loss /= reduce(mul, self.in_shape)
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
loss = (kld_loss + element_wise_loss) / 2 loss = (kld_loss + element_wise_loss) / 2
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss)) return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
@ -54,6 +54,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(epoch=self.current_epoch) return dict(epoch=self.current_epoch)
@ -299,12 +300,14 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
# self.logger.log_metrics(score_dict) # self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step) self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
plt.clf() plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random() maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch) return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)