project Refactor, CNN Classifier Basics
This commit is contained in:
@ -1,8 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lib.modules.utils import FlipTensor
|
||||
from lib.objects.map import MapStorage
|
||||
from lib.objects.map import MapStorage, Map
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
|
||||
class BinaryHomotopicLoss(nn.Module):
|
||||
@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
|
||||
self.map_storage = map_storage
|
||||
self.flipper = FlipTensor()
|
||||
|
||||
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
|
||||
y_flipepd = self.flipper(y)
|
||||
circle = torch.cat((x, y_flipepd), dim=-1)
|
||||
masp = self.map_storage[mapnames].are
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
|
||||
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
|
||||
for basemap in maps:
|
||||
basemap = basemap.as_2d_array
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user