from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule


class CNNRouteGeneratorModel(LightningBaseModule):
    """Convolutional route-generator model built on ``LightningBaseModule``.

    The constructor loads a ``TrajPairData`` dataset and stacks three
    ``ConvModule`` layers, each sized from the ``shape`` of its predecessor.
    The training/validation/test hooks and ``forward`` are currently
    unimplemented placeholders that return ``None``.
    """

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        """Placeholder: no optimizer is configured yet (returns ``None``)."""
        pass

    def validation_step(self, *args, **kwargs):
        """Placeholder: validation step is not implemented (returns ``None``)."""
        pass

    def validation_end(self, outputs):
        """Placeholder: validation aggregation is not implemented (returns ``None``)."""
        pass

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Placeholder: training step is not implemented (returns ``None``)."""
        pass

    def test_step(self, *args, **kwargs):
        """Placeholder: test step is not implemented (returns ``None``)."""
        pass

    def __init__(self, *params):
        """Build the dataset and the convolutional layers.

        Args:
            *params: Forwarded unchanged to ``LightningBaseModule.__init__``.
                After that call ``self.hparams`` is read for
                ``data_param.data_root`` and ``model_param.filters``.
        """
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajPairData(self.hparams.data_param.data_root)

        # Additional Attributes
        self.in_shape = self.dataset.map_shapes_max

        # NN Nodes
        # NOTE(review): ``conv1`` was referenced by ``conv2`` but never
        # created, so construction failed on the missing attribute. It is
        # assumed to consume ``self.in_shape`` with the same hyperparameters
        # as the following layers -- confirm against the intended architecture.
        self.conv1 = ConvModule(self.in_shape,
                                conv_kernel=3,
                                conv_stride=1,
                                conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        # Each subsequent layer takes the previous layer's ``shape`` as input.
        # NOTE(review): all layers read ``filters[0]``; verify whether later
        # layers were meant to use other entries of ``filters``.
        self.conv2 = ConvModule(self.conv1.shape,
                                conv_kernel=3,
                                conv_stride=1,
                                conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        self.conv3 = ConvModule(self.conv2.shape,
                                conv_kernel=3,
                                conv_stride=1,
                                conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])

    def forward(self, x):
        """Placeholder: forward pass is not implemented (returns ``None``)."""
        pass