from ml_lib.modules.losses import BinaryHomotopicLoss
from ml_lib.modules.utils import LightningBaseModule
from objects.map import Map
from objects.trajectory import Trajectory
import torch.nn as nn
class LinearRouteGeneratorModel(LightningBaseModule):
    """Lightning module skeleton for a linear route (trajectory) generator.

    Training consumes batches of ``(traj_x, traj_o, label_x, map_name)``; the
    map is resolved from ``self.map_storage`` (provided by the base module —
    TODO confirm) and scored with a :class:`BinaryHomotopicLoss`.

    NOTE(review): the original file defined ``training_step`` twice; the
    second definition silently shadowed the first and unpacked the batch as a
    2-tuple ``(batch_x, batch_y)``, which contradicts this model's 4-tuple
    batches and its 3-argument ``forward``. The two have been merged below:
    the first definition's batch handling is kept, the loss is computed with
    ``self.criterion`` (the only loss attribute ever set, in ``__init__`` —
    the first definition referenced an undefined ``self.loss``), and the
    Lightning-style loss dict from the second definition is returned.
    """

    # Identifier used by the surrounding framework to register/select this model.
    name = 'LinearRouteGenerator'

    def __init__(self, *params):
        super(LinearRouteGeneratorModel, self).__init__(*params)

        # Homotopy-based binary loss; needs access to the map storage so it
        # can evaluate trajectories against their maps.
        self.criterion = BinaryHomotopicLoss(self.map_storage)

    def forward(self, map_x, traj_x, label_x):
        # TODO: not implemented yet — currently returns None, so training
        # cannot actually run until this produces a predicted trajectory.
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        """Run one optimization step and return the Lightning loss dict.

        :param batch: 4-tuple ``(traj_x, traj_o, label_x, map_name)``.
        :param batch_nb: batch index (unused here, required by Lightning).
        """
        # Type annotations for the unpacked batch elements.
        traj_x: Trajectory
        traj_o: Trajectory
        label_x: int
        map_name: str
        map_x: Map

        # Batch unpacking; traj_o is currently unused but part of the batch layout.
        traj_x, traj_o, label_x, map_name = batch
        map_x = self.map_storage[map_name]

        pred_y = self(map_x, traj_x, label_x)

        # Argument order (target, prediction) follows the original code —
        # TODO confirm against BinaryHomotopicLoss.__call__.
        loss = self.criterion(traj_x, pred_y)
        return dict(loss=loss, log=dict(loss=loss))

    def validation_step(self, *args, **kwargs):
        # Not implemented yet.
        pass

    def validation_end(self, outputs):
        # Not implemented yet.
        pass

    def test_step(self, *args, **kwargs):
        # Not implemented yet.
        pass

    def test_epoch_end(self, outputs):
        # Not implemented yet.
        pass

    def configure_optimizers(self):
        # Not implemented yet — Lightning requires this to return an optimizer
        # before training can run.
        pass