project Refactor, CNN Classifier Basics
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||
from dataset.dataset import TrajPairData
|
||||
from lib.modules.blocks import ConvModule
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
@ -23,5 +25,21 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def __init__(self, *params):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajPairData(self.hparams.data_param.data_root)
|
||||
|
||||
# Additional Attributes
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
|
||||
# NN Nodes
|
||||
|
||||
self.conv1 = ConvModule(self.in_shape, self.hparams.model_param.filters[0])
|
||||
self.conv2 = ConvModule(self.conv1.shape, self.hparams.model_param.filters[0])
|
||||
self.conv3 = ConvModule(self.conv2.shape, self.hparams.model_param.filters[0])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user