project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

View File

@ -1,8 +1,11 @@
from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):
@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipepd = self.flipper(y)
circle = torch.cat((x, y_flipepd), dim=-1)
masp = self.map_storage[mapnames].are
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
for basemap in maps:
basemap = basemap.as_2d_array