New Dataset Generator; how to differentiate the loss function?

Steffen Illium
2020-02-18 21:58:31 +01:00
parent 61c5cb44a0
commit 8424251ca0
13 changed files with 250 additions and 39 deletions

lib/models/blocks.py

@@ -12,7 +12,8 @@ import pytorch_lightning as pl
 ###################
 from torch.utils.data import DataLoader
-from dataset.dataset import TrajData
+from dataset.dataset import TrajDataset
+from lib.objects.map import MapStorage


 class Flatten(nn.Module):
@@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Dataset
-        self.dataset = TrajData('data')
+        self.dataset = TrajDataset('data')
+        self.map_storage = MapStorage(self.hparams.data_param.map_root)

     def size(self):
         return self.shape
@@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
         return

+
+class FlipTensor(nn.Module):
+    def __init__(self, dim=-2):
+        super(FlipTensor, self).__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
+        idx = torch.as_tensor(idx).long()
+        inverted_tensor = x.index_select(self.dim, idx)
+        return inverted_tensor
+
 #
 # Sub - Modules
 ###################
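A side note on the new FlipTensor block: reversing a tensor only permutes its entries, so the operation is fully differentiable, and the index_select construction is equivalent to the built-in torch.flip. A quick sanity check, assuming the usual (batch, step, feature) trajectory layout:

import torch

x = torch.arange(12.).reshape(2, 3, 2)  # (batch, steps, features)
dim = 1                                 # the step axis, i.e. dim=-2 here
idx = torch.as_tensor(list(range(x.size(dim) - 1, -1, -1))).long()
assert torch.equal(x.index_select(dim, idx), torch.flip(x, dims=[dim]))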


@@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule

 class CNNRouteGeneratorModel(LightningBaseModule):

-    @classmethod
-    def name(cls):
-        pass
+    name = 'CNNRouteGenerator'

     def configure_optimizers(self):
         pass
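Replacing the abstract name classmethod stub with a plain class attribute makes the identifier readable without instantiating the model:

assert CNNRouteGeneratorModel.name == 'CNNRouteGenerator'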


@@ -0,0 +1,49 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
+from lib.models.losses import BinaryHomotopicLoss
+from lib.objects.map import Map
+from lib.objects.trajectory import Trajectory
+
+
+class LinearRouteGeneratorModel(LightningBaseModule):
+
+    name = 'LinearRouteGenerator'
+
+    def configure_optimizers(self):
+        pass
+
+    def validation_step(self, *args, **kwargs):
+        pass
+
+    def validation_end(self, outputs):
+        pass
+
+    def training_step(self, batch, batch_nb, *args, **kwargs):
+        # Type annotations
+        traj_x: Trajectory
+        traj_o: Trajectory
+        label_x: int
+        map_name: str
+        map_x: Map
+
+        # Batch unpacking
+        traj_x, traj_o, label_x, map_name = batch
+        map_x = self.map_storage[map_name]
+
+        pred_y = self(map_x, traj_x, label_x)
+        # BinaryHomotopicLoss also needs the map name to look the map up
+        loss = self.loss(traj_x, pred_y, map_name)
+        return dict(loss=loss, log=dict(loss=loss))
+
+    def test_step(self, *args, **kwargs):
+        pass
+
+    def __init__(self, *params):
+        super(LinearRouteGeneratorModel, self).__init__(*params)
+        self.loss = BinaryHomotopicLoss(self.map_storage)
+
+    def forward(self, map_x, traj_x, label_x):
+        pass
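On the question in the commit title: PyTorch Lightning takes care of the differentiation itself. It calls .backward() on the loss tensor returned from training_step, so the only requirement is that every operation between the generator output and the loss value is a differentiable torch op on tensors that carry gradients. A minimal, self-contained illustration of that contract, using a toy nn.Linear as a stand-in for the generator:

import torch
from torch import nn

model = nn.Linear(2, 2)                      # stand-in for the route generator
x, target = torch.randn(4, 2), torch.randn(4, 2)

pred = model(x)
loss = nn.functional.mse_loss(pred, target)  # any chain of torch ops works
loss.backward()                              # what Lightning runs internally
assert model.weight.grad is not None         # gradients reached the parameters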

lib/models/losses.py (new file, 21 lines)

@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from lib.models.blocks import FlipTensor
+from lib.objects.map import MapStorage
+
+
+class BinaryHomotopicLoss(nn.Module):
+    def __init__(self, map_storage: MapStorage):
+        super(BinaryHomotopicLoss, self).__init__()
+        self.map_storage = map_storage
+        self.flipper = FlipTensor()
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor, mapname: str):
+        y_flipped = self.flipper(y)
+        circle = torch.cat((x, y_flipped), dim=-1)
+        masp = self.map_storage[mapname].are
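The construction above concatenates x with the flipped y into one closed curve and then consults the map by name, which suggests a penalty for curves that cross obstacles. A hard lookup into a binary occupancy map is not differentiable in the trajectory coordinates; one common way to keep gradients flowing is to bilinearly interpolate the occupancy grid at the continuous curve points with torch.nn.functional.grid_sample. The following sketch works under that assumption, with a hypothetical soft_obstacle_penalty helper and a made-up occupancy tensor; it is not necessarily the implementation this commit converges on:

import torch
import torch.nn.functional as F

def soft_obstacle_penalty(points, occupancy):
    # points: (n, 2) curve coordinates in [-1, 1]; occupancy: (H, W), 1.0 = blocked.
    # grid_sample interpolates the grid at continuous coordinates, so the
    # penalty stays differentiable w.r.t. the curve points themselves.
    grid = points.view(1, 1, -1, 2)               # (N=1, H_out=1, W_out=n, xy)
    occ = occupancy.view(1, 1, *occupancy.shape)  # (N=1, C=1, H, W)
    return F.grid_sample(occ, grid, align_corners=True).mean()

occ = torch.zeros(32, 32)
occ[10:20, 10:20] = 1.0              # a blocked square, made up for the demo
pts = torch.rand(16, 2) * 2 - 1      # 16 curve points in [-1, 1]
pts.requires_grad_(True)
loss = soft_obstacle_penalty(pts, occ)
loss.backward()
assert pts.grad is not None          # gradients flow back into the curve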