Debugging
lib/models/generators/__init__.py (new file, 0 lines added)
lib/models/generators/cnn.py (new file, 43 lines added)
@@ -0,0 +1,43 @@
from dataset.dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule


class CNNRouteGeneratorModel(LightningBaseModule):

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        pass

    def validation_step(self, *args, **kwargs):
        pass

    def validation_end(self, outputs):
        pass

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        pass

    def test_step(self, *args, **kwargs):
        pass

    def __init__(self, *params):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajPairData(self.hparams.data_param.data_root)

        # Additional Attributes
        self.in_shape = self.dataset.map_shapes_max

        # NN Nodes
        # conv1 is referenced by conv2 below but missing from this hunk; assumed to mirror conv2/conv3
        self.conv1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])
        self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                conv_filters=self.hparams.model_param.filters[0])

    def forward(self, x):
        pass
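ConvModule itself is not part of this commit; the layer chaining above only works if each module exposes a .shape attribute describing its output. A minimal sketch of that convention, purely as an assumption about what lib.modules.blocks.ConvModule provides (the constructor signature is taken from the calls above, everything else is hypothetical):

import torch.nn as nn


class ConvModule(nn.Module):
    # Hypothetical minimal version: a Conv2d plus output-shape bookkeeping,
    # so the next layer can be built from the previous module's .shape.
    def __init__(self, in_shape, conv_kernel, conv_stride, conv_padding, conv_filters):
        super().__init__()
        channels, height, width = in_shape
        self.conv = nn.Conv2d(channels, conv_filters, kernel_size=conv_kernel,
                              stride=conv_stride, padding=conv_padding)
        # Standard output-size formula for a non-dilated convolution
        out_h = (height + 2 * conv_padding - conv_kernel) // conv_stride + 1
        out_w = (width + 2 * conv_padding - conv_kernel) // conv_stride + 1
        self.shape = (conv_filters, out_h, out_w)

    def forward(self, x):
        return self.conv(x)

Under that assumption, conv2 and conv3 above each consume the previous module's .shape, so the stack stays shape-consistent even if kernel, stride, or padding change.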
lib/models/generators/full.py (new file, 47 lines added)
@@ -0,0 +1,47 @@
from lib.modules.blocks import LightningBaseModule
from lib.modules.losses import BinaryHomotopicLoss
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory

import torch.nn as nn

nn.MSELoss  # bare reference left over from debugging; has no effect


class LinearRouteGeneratorModel(LightningBaseModule):

    name = 'LinearRouteGenerator'

    def configure_optimizers(self):
        pass

    def validation_step(self, *args, **kwargs):
        pass

    def validation_end(self, outputs):
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        # Type Annotation
        traj_x: Trajectory
        traj_o: Trajectory
        label_x: int
        map_name: str
        map_x: Map
        # Batch unpacking
        traj_x, traj_o, label_x, map_name = batch
        map_x = self.map_storage[map_name]
        pred_y = self(map_x, traj_x, label_x)

        loss = self.loss(traj_x, pred_y)
        return dict(loss=loss, log=dict(loss=loss))

    def test_step(self, *args, **kwargs):
        pass

    def __init__(self, *params):
        super(LinearRouteGeneratorModel, self).__init__(*params)

        self.loss = BinaryHomotopicLoss(self.map_storage)

    def forward(self, map_x, traj_x, label_x):
        pass
lib/models/generators/recurrent.py (new file, 0 lines added)