VAE Debugging of Route Generator

This commit is contained in:
Si11ium
2020-04-08 08:53:20 +02:00
parent 934dadb558
commit c7971c063f
3 changed files with 25 additions and 26 deletions

View File

@ -79,15 +79,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params, issubclassed=False):
super(CNNRouteGeneratorModel, self).__init__(*params)
if False:
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length)
self.criterion = nn.BCELoss(reduction='sum')
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length)
self.dataset = MyMNIST()
self.criterion = nn.BCELoss(reduction='sum')
# Additional Attributes
###################################################

View File

@ -7,6 +7,7 @@ from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config
import numpy as np
class Logger(LightningLoggerBase):
media_dir = 'media'