from ml_lib.modules.losses import BinaryHomotopicLoss
from ml_lib.modules.utils import LightningBaseModule
from objects.map import Map
from objects.trajectory import Trajectory

import torch.nn as nn


class LinearRouteGeneratorModel(LightningBaseModule):
    """Route generator model trained against a ``BinaryHomotopicLoss`` criterion.

    Most hook methods are still stubs (``pass``); only ``__init__`` and
    ``training_step`` contain logic so far.
    """

    name = 'LinearRouteGenerator'

    def __init__(self, *params):
        """Forward ``params`` to the base module and build the loss criterion.

        NOTE(review): ``self.map_storage`` is read right after the base
        constructor runs, so presumably ``LightningBaseModule`` sets it up from
        ``params`` -- confirm.
        """
        super().__init__(*params)
        self.criterion = BinaryHomotopicLoss(self.map_storage)

    def configure_optimizers(self):
        # TODO: not implemented yet (returns None).
        pass

    def forward(self, map_x, traj_x, label_x):
        # TODO: not implemented yet (returns None).
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        This replaces two conflicting drafts. The later one called
        ``self(batch_x)`` with a single argument, which cannot match
        ``forward(map_x, traj_x, label_x)``. The earlier one used an undefined
        ``self.loss`` and returned nothing.

        Args:
            batch: 4-tuple ``(traj_x, traj_o, label_x, map_name)``.
            batch_nb: Index of the batch; unused here.

        Returns:
            ``dict(loss=loss, log=dict(loss=loss))``.
        """
        # Type Annotation
        traj_x: Trajectory
        traj_o: Trajectory  # unpacked but not used yet
        label_x: int
        map_name: str
        # Batch unpacking
        traj_x, traj_o, label_x, map_name = batch

        # Resolve the map name to the stored Map object.
        map_x: Map = self.map_storage[map_name]
        pred_y = self(map_x, traj_x, label_x)

        # NOTE(review): argument order (traj_x, pred_y) follows the original
        # draft; confirm against BinaryHomotopicLoss's call signature.
        loss = self.criterion(traj_x, pred_y)
        return dict(loss=loss, log=dict(loss=loss))

    def validation_step(self, *args, **kwargs):
        # TODO: not implemented yet.
        pass

    def validation_end(self, outputs):
        # TODO: not implemented yet.
        pass

    def test_step(self, *args, **kwargs):
        # TODO: not implemented yet.
        pass

    def test_epoch_end(self, outputs):
        # TODO: not implemented yet.
        pass