New Dataset Generator, How to differentiate the loss function?

This commit is contained in:
Steffen Illium
2020-02-18 21:58:31 +01:00
parent 61c5cb44a0
commit 8424251ca0
13 changed files with 250 additions and 39 deletions


@@ -12,7 +12,8 @@ import pytorch_lightning as pl
###################
from torch.utils.data import DataLoader
from dataset.dataset import TrajData
from dataset.dataset import TrajDataset
from lib.objects.map import MapStorage
class Flatten(nn.Module):
@@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Dataset
self.dataset = TrajData('data')
self.dataset = TrajDataset('data')
self.map_storage = MapStorage(self.hparams.data_param.map_root)
def size(self):
return self.shape
@@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
return
class FlipTensor(nn.Module):
def __init__(self, dim=-2):
super(FlipTensor, self).__init__()
self.dim = dim
def forward(self, x):
idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
idx = torch.as_tensor(idx).long()
inverted_tensor = x.index_select(self.dim, idx)
return inverted_tensor
#
# Sub - Modules
###################
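
FlipTensor (added above) reverses one dimension by building an explicit index list and calling index_select. A minimal standalone check, not part of the commit, showing that this matches torch.flip, which does the same reversal in a single call and keeps the index on the input tensor's device:

import torch

# Reverse dim -2 once via an explicit index (as FlipTensor does) and once via torch.flip.
x = torch.arange(12.).reshape(3, 4, 1)
idx = torch.arange(x.size(-2) - 1, -1, -1)
assert torch.equal(x.index_select(-2, idx), torch.flip(x, dims=[-2]))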


@@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generat
class CNNRouteGeneratorModel(LightningBaseModule):
@classmethod
def name(cls):
pass
name = 'CNNRouteGenerator'
def configure_optimizers(self):
pass


@@ -0,0 +1,49 @@
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
from lib.models.losses import BinaryHomotopicLoss
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch
import torch.nn.functional as F
import torch.nn as nn
nn.MSELoss
class LinearRouteGeneratorModel(LightningBaseModule):
name = 'LinearRouteGenerator'
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
def training_step(self, batch, batch_nb, *args, **kwargs):
# Type Annotation
traj_x: Trajectory
traj_o: Trajectory
label_x: int
map_name: str
map_x: Map
# Batch unpacking
traj_x, traj_o, label_x, map_name = batch
map_x = self.map_storage[map_name]
pred_y = self(map_x, traj_x, label_x)
loss = self.loss(traj_x, pred_y, map_name)
return dict(loss=loss, log=dict(loss=loss))
def test_step(self, *args, **kwargs):
pass
def __init__(self, *params):
super(LinearRouteGeneratorModel, self).__init__(*params)
self.loss = BinaryHomotopicLoss(self.map_storage)
def forward(self, map_x, traj_x, label_x):
pass
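
A note on the stubs above: with the pytorch-lightning releases current at the time of this commit (the 0.x line), training_step returns a plain dict whose loss entry is backpropagated and whose log entry is forwarded to the logger, which is what the return statement here does. The bare annotations under # Type Annotation only declare hints; the names are actually bound by the tuple unpacking of batch that follows.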

lib/models/losses.py (new file, +21 lines)

@@ -0,0 +1,21 @@
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.models.blocks import FlipTensor
from lib.objects.map import MapStorage
class BinaryHomotopicLoss(nn.Module):
def __init__(self, map_storage: MapStorage):
super(BinaryHomotopicLoss, self).__init__()
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x: torch.Tensor, y: torch.Tensor, mapname: str):
y_flipped = self.flipper(y)
circle = torch.cat((x, y_flipped), dim=-1)
maps = self.map_storage[mapname].as_array
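
The commit message asks how to differentiate this loss, and the forward shown above stops after building the closed curve (input trajectory plus flipped prediction) and fetching the map array, so the differentiable part is not visible here. One possible approach, sketched below purely as an illustration and not taken from this repository: sample the occupancy map at the continuous curve coordinates with F.grid_sample, which is differentiable with respect to the sampling grid, so a penalty on occupied samples yields gradients that push predicted points out of obstacles. All names in the sketch (occupancy_penalty, points_xy, occupancy) are hypothetical.

import torch
import torch.nn.functional as F

def occupancy_penalty(points_xy: torch.Tensor, occupancy: torch.Tensor) -> torch.Tensor:
    # points_xy: (B, N, 2) curve coordinates normalised to [-1, 1]
    # occupancy: (B, 1, H, W) map with 1 = obstacle, 0 = free space
    grid = points_xy.unsqueeze(2)                     # (B, N, 1, 2) sampling grid
    samples = F.grid_sample(occupancy, grid,
                            mode='bilinear', align_corners=True)  # differentiable w.r.t. grid
    return samples.mean()                             # grows when the curve crosses obstacles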


@@ -1,4 +1,7 @@
import shelve
from pathlib import Path
from collections import UserDict
import copy
from math import sqrt
@@ -130,3 +133,30 @@ class Map(object):
# https://matplotlib.org/api/pyplot_summary.html?highlight=colormaps
img = ax.imshow(self.as_array, cmap='Greys_r')
return dict(img=img, fig=fig, ax=ax)
class MapStorage(object):
def __init__(self, map_root, load_all=False):
self.data = dict()
self.map_root = Path(map_root)
if load_all:
for map_file in self.map_root.glob('*.bmp'):
_ = self[map_file.name]
def __getitem__(self, item):
if hasattr(self, item):
return self.__getattribute__(item)
else:
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
self.__setattr__(item, d['map']['map'])
return self[item]
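
MapStorage loads a map lazily the first time a name is requested: __getitem__ opens the matching shelve file, caches the contained Map object as an attribute, and serves every later access from that cache; the d['map']['map'] lookup mirrors what Generator writes into the shelve in the last hunk below (f['map'] = dict(map=self.map, name=self.map.name)). A minimal usage sketch, with the root directory and map name purely hypothetical:

from lib.objects.map import MapStorage

storage = MapStorage('data/maps')       # hypothetical map root
first = storage['some_map']             # hypothetical name; opens data/maps/some_map.pik once
again = storage['some_map']             # now returned from the cached attribute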


@@ -2,15 +2,10 @@ import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Union
from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
class Generator:
@@ -109,7 +104,7 @@ class Generator:
trajectory=trajectory,
labels=labels)
if 'map' not in f:
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
f['map'] = dict(map=self.map, name=self.map.name)
@staticmethod
def _remove_unequal(hom_dict):