Files
hom_traj_gen/lib/models/generators/full.py
2020-02-28 19:11:53 +01:00

48 lines
1.2 KiB
Python

from lib.modules.blocks import LightningBaseModule
from lib.modules.losses import BinaryHomotopicLoss
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch.nn as nn
class LinearRouteGeneratorModel(LightningBaseModule):
    """Route generator model built on ``LightningBaseModule``.

    Only ``__init__`` and ``training_step`` contain logic so far;
    ``forward``, ``configure_optimizers``, ``validation_step``,
    ``validation_end`` and ``test_step`` are placeholders that do nothing
    and return ``None``.
    """

    name = 'LinearRouteGenerator'

    def __init__(self, *params):
        """Initialise the base module and build the training loss.

        All positional ``params`` are forwarded unchanged to
        ``LightningBaseModule.__init__``.
        """
        super().__init__(*params)
        # NOTE(review): ``self.map_storage`` is read here but never assigned in
        # this class, so the base-class constructor presumably sets it -- TODO
        # confirm.
        self.loss = BinaryHomotopicLoss(self.map_storage)

    def forward(self, map_x, traj_x, label_x):
        """Predict from a map, a trajectory and a label.

        Placeholder: not implemented yet; currently returns ``None``.
        """
        pass

    def configure_optimizers(self):
        """Placeholder: no optimizer is configured yet (returns ``None``)."""
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        """Run one training step and return the loss for logging.

        Args:
            batch: 4-item sequence unpacked as
                ``(traj_x, traj_o, label_x, map_name)``.
            batch_nb: Index of the batch; not used by this method.
            *args, **kwargs: Accepted for hook compatibility and ignored.

        Returns:
            ``{'loss': loss, 'log': {'loss': loss}}``.
        """
        # Type annotations only (local annotations are not evaluated at runtime).
        traj_x: Trajectory
        traj_o: Trajectory  # unpacked below but not used in this step
        label_x: int  # NOTE(review): may be a batch of labels/tensor -- confirm
        map_name: str
        map_x: Map

        # Batch unpacking
        traj_x, traj_o, label_x, map_name = batch
        # Look up the Map for this sample by name in ``self.map_storage``.
        map_x = self.map_storage[map_name]

        # NOTE(review): ``self(...)`` presumably dispatches to ``forward`` via
        # ``nn.Module.__call__``; while ``forward`` is a stub, ``pred_y`` would
        # be ``None`` -- confirm before relying on this step.
        pred_y = self(map_x, traj_x, label_x)
        loss = self.loss(traj_x, pred_y)
        return dict(loss=loss, log=dict(loss=loss))

    def validation_step(self, *args, **kwargs):
        """Placeholder: validation is not implemented yet (returns ``None``)."""
        pass

    def validation_end(self, outputs):
        """Placeholder: ignores ``outputs`` and returns ``None``."""
        pass

    def test_step(self, *args, **kwargs):
        """Placeholder: testing is not implemented yet (returns ``None``)."""
        pass