import torch
# NOTE(review): `F` conventionally aliases `torch.nn.functional`; it is unused
# in this module, so the import is kept as-is -- confirm the intended module.
import torch.functional as F
import torch.nn as nn

from lib.models.blocks import (
    ConvModule,
    DeConvModule,
    Generator,
    LightningBaseModule,
    RecurrentModule,
)
from lib.models.losses import BinaryHomotopicLoss
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory


class LinearRouteGeneratorModel(LightningBaseModule):
    """Route generator model built on ``LightningBaseModule``.

    Only ``__init__`` and ``training_step`` contain logic; the optimizer,
    validation, test and ``forward`` hooks are placeholders that currently
    return ``None``.
    """

    name = 'LinearRouteGenerator'

    def __init__(self, *params):
        """Initialise the base module and attach the training loss.

        Args:
            *params: Positional arguments forwarded unchanged to
                ``LightningBaseModule.__init__``.
        """
        super().__init__(*params)
        # Assumes the base-class constructor has already set
        # ``self.map_storage`` -- TODO confirm against LightningBaseModule.
        self.loss = BinaryHomotopicLoss(self.map_storage)

    def configure_optimizers(self):
        """Placeholder: no optimizer is configured yet (returns ``None``)."""
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        """Run one training step and return the loss.

        Args:
            batch: Four-item sequence unpacked as
                ``(traj_x, traj_o, label_x, map_name)``.
            batch_nb: Index of the batch; not used here.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            dict: ``{'loss': loss, 'log': {'loss': loss}}``.
        """
        # Type annotations for the unpacked batch members.
        traj_x: Trajectory
        traj_o: Trajectory
        label_x: int
        map_name: str
        map_x: Map

        # Batch unpacking; ``traj_o`` is unpacked but not used in this step.
        traj_x, traj_o, label_x, map_name = batch
        map_x = self.map_storage[map_name]

        pred_y = self(map_x, traj_x, label_x)
        loss = self.loss(traj_x, pred_y)
        return dict(loss=loss, log=dict(loss=loss))

    def validation_step(self, *args, **kwargs):
        """Placeholder: validation is not implemented (returns ``None``)."""
        pass

    def validation_end(self, outputs):
        """Placeholder: validation aggregation is not implemented."""
        pass

    def test_step(self, *args, **kwargs):
        """Placeholder: testing is not implemented (returns ``None``)."""
        pass

    def forward(self, map_x, traj_x, label_x):
        """Placeholder forward pass; currently returns ``None``.

        NOTE(review): ``training_step`` passes this return value straight to
        ``self.loss``, so training presumably cannot work until this is
        implemented -- confirm.
        """
        pass