import torch from torch import nn import torch.nn.functional as F import pytorch_lightning as pl from lib.models.blocks import FlipTensor from lib.objects.map import MapStorage class BinaryHomotopicLoss(nn.Module): def __init__(self, map_storage: MapStorage): super(BinaryHomotopicLoss, self).__init__() self.map_storage = map_storage self.flipper = FlipTensor() def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str): y_flipepd = self.flipper(y) circle = torch.cat((x, y_flipepd), dim=-1) masp = self.map_storage[mapname].are