from ml_lib.modules.losses import BinaryHomotopicLoss
from ml_lib.modules.utils import LightningBaseModule
from objects.map import Map
from objects.trajectory import Trajectory

import torch.nn as nn


class LinearRouteGeneratorModel(LightningBaseModule):
    """Route-generator Lightning module (work in progress).

    Only ``training_step`` contains logic. ``forward``, ``configure_optimizers``
    and the validation/test hooks are still stubs that return ``None``.
    """

    name = 'LinearRouteGenerator'

    def __init__(self, *params):
        """Forward ``params`` to ``LightningBaseModule`` and build the loss.

        NOTE(review): ``self.map_storage`` is not assigned in this class; it is
        presumably provided by ``LightningBaseModule.__init__`` — confirm.
        """
        super().__init__(*params)

        self.criterion = BinaryHomotopicLoss(self.map_storage)

    def forward(self, map_x, traj_x, label_x):
        """Stub for the model's forward pass.

        TODO: not implemented — currently returns ``None``, so the
        ``pred_y`` seen by ``training_step`` is ``None``.
        """
        pass

    def training_step(self, batch, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        Args:
            batch: 4-tuple ``(traj_x, traj_o, label_x, map_name)``.
            batch_nb: Batch index (unused here).

        Returns:
            ``dict(loss=loss, log=dict(loss=loss))``.
        """
        # Tuple-unpacking targets cannot be annotated inline, so declare first.
        traj_x: Trajectory
        traj_o: Trajectory  # unpacked but not used yet
        label_x: int
        map_name: str
        traj_x, traj_o, label_x, map_name = batch

        # The batch only carries the map's name; look the map up by that key.
        map_x: Map = self.map_storage[map_name]
        pred_y = self(map_x, traj_x, label_x)

        # Use the criterion built in __init__ (``self.loss`` is never assigned).
        loss = self.criterion(traj_x, pred_y)

        return dict(loss=loss, log=dict(loss=loss))

    def configure_optimizers(self):
        """Stub: no optimizer is configured yet (returns ``None``).

        TODO: return an optimizer over ``self.parameters()`` once the model
        has trainable layers.
        """
        pass

    def validation_step(self, *args, **kwargs):
        """Stub validation step (returns ``None``)."""
        pass

    def validation_end(self, outputs):
        """Stub validation aggregation hook (returns ``None``)."""
        pass

    def test_step(self, *args, **kwargs):
        """Stub test step (returns ``None``)."""
        pass

    def test_epoch_end(self, outputs):
        """Stub test aggregation hook (returns ``None``)."""
        pass