22 lines
600 B
Python
22 lines
600 B
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import pytorch_lightning as pl
|
|
|
|
from lib.models.blocks import FlipTensor
|
|
from lib.objects.map import MapStorage
|
|
|
|
|
|
class BinaryHomotopicLoss(nn.Module):
    """Loss module that combines two tensors and a per-map lookup.

    Holds a ``MapStorage`` and a ``FlipTensor`` helper. ``forward`` flips
    ``y``, concatenates it after ``x`` along the last dimension, and reads an
    attribute of the map selected by ``mapnames``.

    NOTE(review): ``forward`` appears unfinished -- it computes intermediates
    but never returns a value, so calling the module currently yields ``None``.
    """

    def __init__(self, map_storage: MapStorage):
        """Store the map container and create the flipping helper.

        Args:
            map_storage: Container indexed by the ``mapnames`` argument of
                ``forward``.
        """
        super().__init__()
        self.map_storage = map_storage
        self.flipper = FlipTensor()

    def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
        """Build the concatenated tensor and look up the requested map.

        Args:
            x: Tensor placed first in the concatenation.
            y: Tensor passed through ``self.flipper`` and appended after ``x``
                along the last dimension.
            mapnames: Key used to index ``self.map_storage``.

        Returns:
            None. TODO(review): no loss value is computed or returned yet.
        """
        y_flipped = self.flipper(y)
        # presumably x followed by the flipped y forms a closed loop
        # ("circle") -- TODO confirm against FlipTensor's semantics.
        circle = torch.cat((x, y_flipped), dim=-1)
        # Index with the declared parameter `mapnames`; the previous code used
        # the undefined name `mapname`, which raised NameError on every call.
        # NOTE(review): `.are` looks truncated (possibly `.area`) -- confirm
        # against the MapStorage item type before relying on it.
        masp = self.map_storage[mapnames].are
|
|
|
|
|