restructured
This commit is contained in:
@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
|
||||
|
||||
loss = (kld_loss + element_wise_loss) / 2
|
||||
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
|
||||
|
Reference in New Issue
Block a user