project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-02-19 21:11:42 +01:00
parent 8424251ca0
commit 78f0df8a2a
16 changed files with 622 additions and 560 deletions

View File

@ -1,4 +1,6 @@
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
from dataset.dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule
class CNNRouteGeneratorModel(LightningBaseModule):
@ -23,5 +25,21 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset
self.dataset = TrajPairData(self.hparams.data_param.data_root)
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
# NN Nodes
self.conv1 = ConvModule(self.in_shape, self.hparams.model_param.filters[0])
self.conv2 = ConvModule(self.conv1.shape, self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, self.hparams.model_param.filters[0])
def forward(self, x):
pass