Files
hom_traj_gen/lib/modules/losses.py
2020-02-19 21:11:42 +01:00

18 lines
536 B
Python

import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
class BinaryHomotopicLoss(nn.Module):
def __init__(self, map_storage: MapStorage):
    """Build the loss around a map lookup and a tensor-flipping helper.

    Args:
        map_storage: Map container; ``forward`` indexes it by map name
            (``self.map_storage[mapnames]``).
    """
    super().__init__()
    # Kept as an attribute so forward can index it per call.
    self.map_storage = map_storage
    # Applied to ``y`` in forward before it is concatenated with ``x`` on the
    # last dim. NOTE(review): presumably reverses the sequence — confirm.
    self.flipper = FlipTensor()
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipepd = self.flipper(y)
circle = torch.cat((x, y_flipepd), dim=-1)
masp = self.map_storage[mapnames].are