New Dataset Generator, How to differentiate the loss function?
This commit is contained in:
@ -12,7 +12,8 @@ import pytorch_lightning as pl
|
||||
###################
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset.dataset import TrajData
|
||||
from dataset.dataset import TrajDataset
|
||||
from lib.objects.map import MapStorage
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
# Dataset
|
||||
self.dataset = TrajData('data')
|
||||
self.dataset = TrajDataset('data')
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
|
||||
return
|
||||
|
||||
|
||||
class FlipTensor(nn.Module):
|
||||
def __init__(self, dim=-2):
|
||||
super(FlipTensor, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
|
||||
idx = torch.as_tensor(idx).long()
|
||||
inverted_tensor = x.index_select(self.dim, idx)
|
||||
return inverted_tensor
|
||||
|
||||
#
|
||||
# Sub - Modules
|
||||
###################
|
||||
|
@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generat
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
pass
|
||||
name = 'CNNRouteGenerator'
|
||||
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
@ -0,0 +1,49 @@
|
||||
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||
from lib.models.losses import BinaryHomotopicLoss
|
||||
from lib.objects.map import Map
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
import torch
|
||||
import torch.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
nn.MSELoss
|
||||
|
||||
class LinearRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
name = 'LinearRouteGenerator'
|
||||
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
pass
|
||||
|
||||
def training_step(self, batch, batch_nb, *args, **kwargs):
|
||||
# Type Annotation
|
||||
traj_x: Trajectory
|
||||
traj_o: Trajectory
|
||||
label_x: int
|
||||
map_name: str
|
||||
map_x: Map
|
||||
# Batch unpacking
|
||||
traj_x, traj_o, label_x, map_name = batch
|
||||
map_x = self.map_storage[map_name]
|
||||
pred_y = self(map_x, traj_x, label_x)
|
||||
|
||||
loss = self.loss(traj_x, pred_y)
|
||||
return dict(loss=loss, log=dict(loss=loss))
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(LinearRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
self.loss = BinaryHomotopicLoss(self.map_storage)
|
||||
|
||||
def forward(self, map_x, traj_x, label_x):
|
||||
pass
|
||||
|
21
lib/models/losses.py
Normal file
21
lib/models/losses.py
Normal file
@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from lib.models.blocks import FlipTensor
|
||||
from lib.objects.map import MapStorage
|
||||
|
||||
|
||||
class BinaryHomotopicLoss(nn.Module):
|
||||
def __init__(self, map_storage: MapStorage):
|
||||
super(BinaryHomotopicLoss, self).__init__()
|
||||
self.map_storage = map_storage
|
||||
self.flipper = FlipTensor()
|
||||
|
||||
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
|
||||
y_flipepd = self.flipper(y)
|
||||
circle = torch.cat((x, y_flipepd), dim=-1)
|
||||
masp = self.map_storage[mapname].are
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from collections import UserDict
|
||||
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
@ -130,3 +133,30 @@ class Map(object):
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
|
||||
|
||||
class MapStorage(object):
|
||||
|
||||
def __init__(self, map_root, load_all=False):
|
||||
self.data = dict()
|
||||
self.map_root = Path(map_root)
|
||||
if load_all:
|
||||
for map_file in self.map_root.glob('*.bmp'):
|
||||
_ = self[map_file.name]
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item in hasattr(self, item):
|
||||
return self.__getattribute__(item)
|
||||
else:
|
||||
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
|
||||
self.__setattr__(item, d['map']['map'])
|
||||
return self[item]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -2,15 +2,10 @@ import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
from lib.utils.parallel import run_n_in_parallel
|
||||
|
||||
|
||||
class Generator:
|
||||
@ -109,7 +104,7 @@ class Generator:
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
||||
f['map'] = dict(map=self.map, name=self.map.name)
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
|
Reference in New Issue
Block a user