44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
from dataset.dataset import TrajPairData
|
|
from lib.modules.blocks import ConvModule
|
|
from lib.modules.utils import LightningBaseModule
|
|
|
|
|
|
class CNNRouteGeneratorModel(LightningBaseModule):
    """CNN-based route generator model built on ``LightningBaseModule``.

    Only ``__init__`` does real work so far. It loads a ``TrajPairData``
    dataset and builds a chain of three ``ConvModule`` layers. ``forward``
    and the optimizer/train/validation/test hooks are still placeholders
    that return ``None``.
    """

    name = 'CNNRouteGenerator'

    def __init__(self, *params):
        """Load the dataset and build the convolutional layers.

        Args:
            *params: Forwarded unchanged to ``LightningBaseModule.__init__``.
                ``self.hparams`` is read immediately afterwards, so the base
                class presumably sets it from these params. TODO: confirm.
                Fields read here: ``hparams.data_param.data_root`` and
                ``hparams.model_param.filters``.
        """
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajPairData(self.hparams.data_param.data_root)

        # Additional Attributes
        # Presumably the largest map shape in the dataset, used as the network
        # input shape. TODO: confirm against TrajPairData.map_shapes_max.
        self.in_shape = self.dataset.map_shapes_max

        # NN Nodes
        # Each ConvModule takes the previous layer's ``.shape`` as its input
        # shape. conv2 reads ``self.conv1.shape``, so conv1 must be created
        # first. It consumes ``self.in_shape`` as the head of the chain.
        self.conv1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        # NOTE(review): all three layers use filters[0]. If model_param.filters
        # lists per-layer widths, conv2/conv3 may be meant to use filters[1] /
        # filters[2]. Confirm before changing.
        self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])

    def forward(self, x):
        """Forward pass. Not implemented yet; returns ``None``."""
        # TODO: presumably x -> conv1 -> conv2 -> conv3. Confirm how
        # ConvModule instances are meant to be called.
        pass

    def configure_optimizers(self):
        """Optimizer setup hook. Placeholder; returns ``None``."""
        pass

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Training-step hook. Placeholder; returns ``None``.

        Args:
            batch_xy: The current batch. Presumably an (x, y) pair, judging by
                the name. TODO: confirm.
            batch_nb: Index of the batch within the epoch. TODO: confirm.
            *args, **kwargs: Accepted and ignored.
        """
        pass

    def validation_step(self, *args, **kwargs):
        """Validation-step hook. Placeholder; returns ``None``."""
        pass

    def validation_end(self, outputs):
        """End-of-validation hook receiving the collected ``outputs``.

        Placeholder; returns ``None``.
        """
        pass

    def test_step(self, *args, **kwargs):
        """Test-step hook. Placeholder; returns ``None``."""
        pass
|