New Dataset Generator, How to differentiate the loss function?

This commit is contained in:
Steffen Illium 2020-02-18 21:58:31 +01:00
parent 61c5cb44a0
commit 8424251ca0
13 changed files with 250 additions and 39 deletions

71
.gitignore vendored Normal file
View File

@ -0,0 +1,71 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

13
.idea/deployment.xml generated
View File

@ -1,18 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
<serverData>
<paths name="ErLoWa-AiMachine">
<paths name="steffen@aimachine:22">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="traj_gen-AiMachine">
<serverdata>
<mappings>
<mapping deploy="/hom_traj_gen" local="$PROJECT_DIR$" web="/" />
<mapping deploy="\" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>

View File

@ -3,8 +3,10 @@
<words>
<w>conv</w>
<w>homotopic</w>
<w>hparams</w>
<w>hyperparamter</w>
<w>numlayers</w>
<w>traj</w>
</words>
</dictionary>
</component>

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" project-jdk-type="Python SDK" />
</project>

View File

@ -1,28 +1,31 @@
import shelve
from pathlib import Path
from typing import Union
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
from lib.preprocessing.generator import Generator
class TrajDataset(Dataset):
class TrajPairDataset(Dataset):
def __init__(self, data):
super(TrajDataset, self).__init__()
super(TrajPairDataset, self).__init__()
self.alternatives = data['alternatives']
self.trajectory = data['trajectory']
self.labels = data['labels']
self.mapname = data['map']['name'][4:] if data['map']['name'].startswith('map_') else data['map']['name']
def __len__(self):
return len(self.alternatives)
def __getitem__(self, item):
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname
class DataSetMapping(Dataset):
class DatasetMapping(Dataset):
def __init__(self, dataset, mapping):
self._dataset = dataset
self._mapping = mapping
@ -34,12 +37,12 @@ class DataSetMapping(Dataset):
return self._dataset[self._mapping[item]]
class TrajData(object):
class TrajPairData(object):
@property
def name(self):
return self.__class__.__name__
def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000, alternatives=10,
train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
self.rebuild = rebuild
@ -49,13 +52,13 @@ class TrajData(object):
self.mapname = mapname
self.train_split, self.val_split, self.test_split = train_val_test_split
self.data_root = Path(data_root)
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self._dataset = None
self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()
def _build_data_on_demand(self):
maps_root = Path() / 'res' / 'maps'
map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
assert maps_root.exists()
map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
assert self.maps_root.exists()
dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
if dataset_file.exists() and self.rebuild:
dataset_file.unlink()
@ -68,7 +71,7 @@ class TrajData(object):
def _load_dataset(self):
assert self._build_data_on_demand()
with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
indices = torch.randperm(len(dataset))
train_size = int(len(dataset) * self.train_split)
@ -82,15 +85,50 @@ class TrajData(object):
@property
def train_dataset(self):
return DataSetMapping(self._dataset, self._train_map)
return DatasetMapping(self._dataset, self._train_map)
@property
def val_dataset(self):
return DataSetMapping(self._dataset, self._val_map)
return DatasetMapping(self._dataset, self._val_map)
@property
def test_dataset(self):
return DataSetMapping(self._dataset, self._test_map)
return DatasetMapping(self._dataset, self._test_map)
def get_datasets(self):
return self.train_dataset, self.val_dataset, self.test_dataset
class TrajDataset(Dataset):
def __init__(self, data_root, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100000, **_):
super(TrajDataset, self).__init__()
self.mapname = mapname
self.maps_root = Path(maps_root)
self.data_root = data_root
self._len = length
self._map_obj = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
def __len__(self):
return self._len
def __getitem__(self, item):
trajectory = self._map_obj.get_random_trajectory()
label = choice([0, 1])
return trajectory.vertices, None, label, self.mapname  # None: placeholder for the (absent) alternative trajectory
@property
def train_dataset(self):
return self
@property
def val_dataset(self):
return self
@property
def test_dataset(self):
return self
def get_datasets(self):
return self, self, self
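
Note that TrajDataset.__getitem__ returns None as a placeholder for the alternative trajectory, which PyTorch's default_collate cannot batch; a custom collate_fn is one way around this. A minimal sketch, assuming variable-length trajectories are padded with pad_sequence (the field layout matches the tuple returned above):

    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    def traj_collate(batch):
        # Each item is (vertices, None, label, mapname); drop the None placeholder.
        vertices = pad_sequence([torch.as_tensor(b[0], dtype=torch.float32) for b in batch],
                                batch_first=True)
        labels = torch.as_tensor([b[2] for b in batch])
        mapnames = [b[3] for b in batch]
        return vertices, labels, mapnames

    # loader = DataLoader(traj_dataset, batch_size=32, collate_fn=traj_collate)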

View File

@ -12,7 +12,8 @@ import pytorch_lightning as pl
###################
from torch.utils.data import DataLoader
from dataset.dataset import TrajData
from dataset.dataset import TrajDataset
from lib.objects.map import MapStorage
class Flatten(nn.Module):
@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Dataset
self.dataset = TrajData('data')
self.dataset = TrajDataset('data')
self.map_storage = MapStorage(self.hparams.data_param.map_root)
def size(self):
return self.shape
@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
return
class FlipTensor(nn.Module):
def __init__(self, dim=-2):
super(FlipTensor, self).__init__()
self.dim = dim
def forward(self, x):
# Build the reversed index on x's device so flipping also works for GPU tensors.
idx = torch.arange(x.size(self.dim) - 1, -1, -1, device=x.device)
inverted_tensor = x.index_select(self.dim, idx)
return inverted_tensor
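
FlipTensor reverses its input along one dimension; newer PyTorch versions offer the same operation as the built-in torch.flip, which a quick sanity check confirms:

    x = torch.arange(6.).view(2, 3)
    assert torch.equal(FlipTensor(dim=-1)(x), torch.flip(x, dims=[-1]))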
#
# Sub - Modules
###################

View File

@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generat
class CNNRouteGeneratorModel(LightningBaseModule):
@classmethod
def name(cls):
pass
name = 'CNNRouteGenerator'
def configure_optimizers(self):
pass

View File

@ -0,0 +1,49 @@
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
from lib.models.losses import BinaryHomotopicLoss
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch
import torch.nn.functional as F
import torch.nn as nn
# nn.MSELoss
class LinearRouteGeneratorModel(LightningBaseModule):
name = 'LinearRouteGenerator'
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
def training_step(self, batch, batch_nb, *args, **kwargs):
# Type Annotation
traj_x: Trajectory
traj_o: Trajectory
label_x: int
map_name: str
map_x: Map
# Batch unpacking
traj_x, traj_o, label_x, map_name = batch
map_x = self.map_storage[map_name]
pred_y = self(map_x, traj_x, label_x)
loss = self.loss(traj_x, pred_y, map_name)
return dict(loss=loss, log=dict(loss=loss))
def test_step(self, *args, **kwargs):
pass
def __init__(self, *params):
super(LinearRouteGeneratorModel, self).__init__(*params)
self.loss = BinaryHomotopicLoss(self.map_storage)
def forward(self, map_x, traj_x, label_x):
pass
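
configure_optimizers is still stubbed out above; in pytorch-lightning it is expected to return an optimizer for the module's parameters. A minimal sketch to drop into the class (the learning rate is an arbitrary placeholder, not part of this commit):

    def configure_optimizers(self):
        # Plain Adam over all model parameters.
        return torch.optim.Adam(self.parameters(), lr=1e-3)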

21
lib/models/losses.py Normal file
View File

@ -0,0 +1,21 @@
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.models.blocks import FlipTensor
from lib.objects.map import MapStorage
class BinaryHomotopicLoss(nn.Module):
def __init__(self, map_storage: MapStorage):
super(BinaryHomotopicLoss, self).__init__()
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipped = self.flipper(y)
circle = torch.cat((x, y_flipped), dim=-1)
masp = self.map_storage[mapnames].are
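
The question in the commit title, how to differentiate this loss, comes down to the obstacle test: checking whether the closed curve (x concatenated with the flipped y) crosses an occupied map cell is a discrete operation. One differentiable relaxation, not part of this commit, is to sample the occupancy map at the curve's points with F.grid_sample, whose bilinear interpolation has gradients with respect to the sampling coordinates. A sketch under that assumption (the [-1, 1] coordinate convention and all names are illustrative):

    import torch
    import torch.nn.functional as F

    def soft_obstacle_penalty(curve_xy, occupancy):
        # curve_xy: (B, N, 2) curve points in [-1, 1] map coordinates.
        # occupancy: (B, 1, H, W) obstacle map, 1.0 = blocked.
        grid = curve_xy.unsqueeze(2)  # (B, N, 1, 2) sampling grid
        samples = F.grid_sample(occupancy, grid,
                                mode='bilinear', align_corners=True)  # (B, 1, N, 1)
        # Mean obstacle mass along the curve; gradients flow back into curve_xy.
        return samples.mean()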

View File

@ -1,4 +1,7 @@
import shelve
from pathlib import Path
from collections import UserDict
import copy
from math import sqrt
@ -130,3 +133,30 @@ class Map(object):
# https://matplotlib.org/api/pyplot_summary.html?highlight=colormaps
img = ax.imshow(self.as_array, cmap='Greys_r')
return dict(img=img, fig=fig, ax=ax)
class MapStorage(object):
def __init__(self, map_root, load_all=False):
self.data = dict()
self.map_root = Path(map_root)
if load_all:
for map_file in self.map_root.glob('*.bmp'):
_ = self[map_file.name]
def __getitem__(self, item):
if hasattr(self, item):
return self.__getattribute__(item)
else:
with shelve.open(str(self.map_root / f'{item}.pik'), flag='r') as d:
self.__setattr__(item, d['map']['map'])
return self[item]
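
MapStorage caches each map as an instance attribute on first access, so later lookups skip the shelve file entirely. Assuming the generator has written a tate_sw.pik shelve under map_root, usage would look like:

    storage = MapStorage('res/maps')
    map_a = storage['tate_sw']   # first access: opens the shelve, caches the map
    map_b = storage['tate_sw']   # second access: served from the attribute cache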

View File

@ -2,15 +2,10 @@ import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Union
from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
class Generator:
@ -109,7 +104,7 @@ class Generator:
trajectory=trajectory,
labels=labels)
if 'map' not in f:
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
f['map'] = dict(map=self.map, name=self.map.name)
@staticmethod
def _remove_unequal(hom_dict):

View File

@ -10,7 +10,7 @@ import warnings
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from dataset.dataset import TrajData
from dataset.dataset import TrajPairData
from lib.utils.config import Config
from lib.utils.logging import Logger
@ -32,7 +32,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='../data/rpoot', help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@ -65,7 +66,7 @@ config = Config.read_namespace(args)
# TESTING ONLY #
# =============================================================================
hparams = config.model_paramters
dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=2500)
dataset = TrajPairData('data', mapname='tate', alternatives=10000, trajectories=2500)
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
batch_size=hparams.data_param.batchsize,
num_workers=hparams.data_param.worker)