From 8424251ca0f13f474fb2bdcf2b4b8b9b6b7741b7 Mon Sep 17 00:00:00 2001
From: Steffen Illium
Date: Tue, 18 Feb 2020 21:58:31 +0100
Subject: [PATCH] New Dataset Generator, How to differentiate the loss function?

---
 .gitignore                     | 71 ++++++++++++++++++++++++++++++++++
 .idea/deployment.xml           | 13 ++-----
 .idea/dictionaries/steffen.xml |  2 +
 .idea/hom_traj_gen.iml         |  2 +-
 .idea/misc.xml                 |  2 +-
 dataset/dataset.py             | 64 +++++++++++++++++++++++-------
 lib/models/blocks.py           | 17 +++++++-
 lib/models/cnn.py              |  4 +-
 lib/models/full.py             | 49 +++++++++++++++++++++++
 lib/models/losses.py           | 21 ++++++++++
 lib/objects/map.py             | 30 ++++++++++++++
 lib/preprocessing/generator.py |  7 +---
 main.py                        |  7 ++--
 13 files changed, 250 insertions(+), 39 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 lib/models/losses.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..37be74a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,71 @@
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser

diff --git a/.idea/deployment.xml b/.idea/deployment.xml
index e99a9d5..80ddd26 100644
--- a/.idea/deployment.xml
+++ b/.idea/deployment.xml
@@ -1,18 +1,11 @@
 [XML hunk content lost in extraction]

diff --git a/.idea/dictionaries/steffen.xml b/.idea/dictionaries/steffen.xml
index 92d2dbb..21f6ccf 100644
--- a/.idea/dictionaries/steffen.xml
+++ b/.idea/dictionaries/steffen.xml
@@ -3,8 +3,10 @@
       <w>conv</w>
       <w>homotopic</w>
+      <w>hparams</w>
       <w>hyperparamter</w>
       <w>numlayers</w>
+      <w>traj</w>
     </words>
   </dictionary>
 </component>
\ No newline at end of file

diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml
index 4b1d9c2..241d6f7 100644
--- a/.idea/hom_traj_gen.iml
+++ b/.idea/hom_traj_gen.iml
@@ -2,7 +2,7 @@
 [XML hunk content lost in extraction]
\ No newline at end of file

diff --git a/.idea/misc.xml b/.idea/misc.xml
index f164374..ac76099 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
 [XML hunk content lost in extraction]
\ No newline at end of file

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 49130c5..24ce303 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,28 +1,31 @@
 import shelve
 from pathlib import Path
+from typing import Union
 
 import torch
+from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 
 from lib.objects.map import Map
 from lib.preprocessing.generator import Generator
 
 
-class TrajDataset(Dataset):
+class TrajPairDataset(Dataset):
     def __init__(self, data):
-        super(TrajDataset, self).__init__()
+        super(TrajPairDataset, self).__init__()
         self.alternatives = data['alternatives']
         self.trajectory = data['trajectory']
         self.labels = data['labels']
+        self.mapname = data['map']['name'][4:] if data['map']['name'].startswith('map_') else data['map']['name']
 
     def __len__(self):
         return len(self.alternatives)
 
     def __getitem__(self, item):
-        return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]
+        return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname
 
 
-class DataSetMapping(Dataset):
+class DatasetMapping(Dataset):
     def __init__(self, dataset, mapping):
         self._dataset = dataset
         self._mapping = mapping
@@ -34,12 +37,12 @@ class DataSetMapping(Dataset):
         return self._dataset[self._mapping[item]]
 
 
-class TrajData(object):
+class TrajPairData(object):
     @property
     def name(self):
         return self.__class__.__name__
 
-    def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
+    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000, alternatives=10,
                  train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
 
         self.rebuild = rebuild
@@ -49,13 +52,13 @@ class TrajPairData(object):
         self.mapname = mapname
         self.train_split, self.val_split, self.test_split = train_val_test_split
         self.data_root = Path(data_root)
+        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
         self._dataset = None
         self._dataset, \
             self._train_map, self._val_map, self._test_map = self._load_dataset()
 
     def _build_data_on_demand(self):
-        maps_root = Path() / 'res' / 'maps'
-        map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
-        assert maps_root.exists()
+        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
+        assert self.maps_root.exists()
         dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
         if dataset_file.exists() and self.rebuild:
             dataset_file.unlink()
@@ -68,7 +71,7 @@ class TrajPairData(object):
     def _load_dataset(self):
         assert self._build_data_on_demand()
         with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
-            dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
+            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
             indices = torch.randperm(len(dataset))
 
             train_size = int(len(dataset) * self.train_split)
@@ -82,15 +85,50 @@ class TrajPairData(object):
 
     @property
     def train_dataset(self):
-        return DataSetMapping(self._dataset, self._train_map)
+        return DatasetMapping(self._dataset, self._train_map)
 
     @property
     def val_dataset(self):
-        return DataSetMapping(self._dataset, self._val_map)
+        return DatasetMapping(self._dataset, self._val_map)
 
     @property
     def test_dataset(self):
-        return DataSetMapping(self._dataset, self._test_map)
+        return DatasetMapping(self._dataset, self._test_map)
 
     def get_datasets(self):
         return self.train_dataset, self.val_dataset, self.test_dataset
+
+
+class TrajDataset(Dataset):
+
+    def __init__(self, data_root, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100_000, **_):
+        super(TrajDataset, self).__init__()
+        self.mapname = mapname
+        # fall back to the default map folder, as TrajPairData does
+        self.maps_root = Path(maps_root) if maps_root else Path() / 'res' / 'maps'
+        self.data_root = data_root
+        self._len = length
+
+        self._map_obj = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
+
+    def __len__(self):
+        return self._len
+
+    def __getitem__(self, item):
+        trajectory = self._map_obj.get_random_trajectory()
+        label = choice([0, 1])
+        return trajectory.vertices, None, label, self.mapname
+
+    @property
+    def train_dataset(self):
+        return self
+
+    @property
+    def val_dataset(self):
+        return self
+
+    @property
+    def test_dataset(self):
+        return self
+
+    def get_datasets(self):
+        return self, self, self
diff --git a/lib/models/blocks.py b/lib/models/blocks.py
index 9ad85de..d972c30 100644
--- a/lib/models/blocks.py
+++ b/lib/models/blocks.py
@@ -12,7 +12,8 @@ import pytorch_lightning as pl
 ###################
 from torch.utils.data import DataLoader
 
-from dataset.dataset import TrajData
+from dataset.dataset import TrajDataset
+from lib.objects.map import MapStorage
 
 
 class Flatten(nn.Module):
@@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Dataset
-        self.dataset = TrajData('data')
+        self.dataset = TrajDataset('data')
+        self.map_storage = MapStorage(self.hparams.data_param.map_root)
 
     def size(self):
         return self.shape
@@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
         return
 
 
+class FlipTensor(nn.Module):
+    def __init__(self, dim=-2):
+        super(FlipTensor, self).__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
+        # keep the index on the input's device so GPU tensors work too
+        idx = torch.as_tensor(idx, device=x.device).long()
+        inverted_tensor = x.index_select(self.dim, idx)
+        return inverted_tensor
+
+
 #
 # Sub - Modules
 ###################
diff --git a/lib/models/cnn.py b/lib/models/cnn.py
index 5b0d9d1..66d07f1 100644
--- a/lib/models/cnn.py
+++ b/lib/models/cnn.py
@@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generat
 
 class CNNRouteGeneratorModel(LightningBaseModule):
 
-    @classmethod
-    def name(cls):
-        pass
+    name = 'CNNRouteGenerator'
 
     def configure_optimizers(self):
         pass
diff --git a/lib/models/full.py b/lib/models/full.py
index e69de29..312f43b 100644
--- a/lib/models/full.py
+++ b/lib/models/full.py
@@ -0,0 +1,49 @@
+from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
+from lib.models.losses import BinaryHomotopicLoss
+from lib.objects.map import Map
+from lib.objects.trajectory import Trajectory
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+
+class LinearRouteGeneratorModel(LightningBaseModule):
+
+    name = 'LinearRouteGenerator'
+
+    def configure_optimizers(self):
+        pass
+
+    def validation_step(self, *args, **kwargs):
+        pass
+
+    def validation_end(self, outputs):
+        pass
+
+    def training_step(self, batch, batch_nb, *args, **kwargs):
+        # Type annotations
+        traj_x: Trajectory
+        traj_o: Trajectory
+        label_x: int
+        map_name: str
+        map_x: Map
+        # Batch unpacking
+        traj_x, traj_o, label_x, map_name = batch
+        map_x = self.map_storage[map_name]
+        pred_y = self(map_x, traj_x, label_x)
+
+        loss = self.loss(traj_x, pred_y, map_name)
+        return dict(loss=loss, log=dict(loss=loss))
+
+    def test_step(self, *args, **kwargs):
+        pass
+
+    def __init__(self, *params):
+        super(LinearRouteGeneratorModel, self).__init__(*params)
+
+        self.loss = BinaryHomotopicLoss(self.map_storage)
+
+    def forward(self, map_x, traj_x, label_x):
+        pass
diff --git a/lib/models/losses.py b/lib/models/losses.py
new file mode 100644
index 0000000..a2cbc45
--- /dev/null
+++ b/lib/models/losses.py
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from lib.models.blocks import FlipTensor
+from lib.objects.map import MapStorage
+
+
+class BinaryHomotopicLoss(nn.Module):
+    def __init__(self, map_storage: MapStorage):
+        super(BinaryHomotopicLoss, self).__init__()
+        self.map_storage = map_storage
+        self.flipper = FlipTensor()
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
+        # Close x and the reversed y into a single loop, then look up the map.
+        y_flipped = self.flipper(y)
+        circle = torch.cat((x, y_flipped), dim=-1)
+        map_array = self.map_storage[mapnames].as_array
+        # TODO: derive a differentiable penalty from circle and map_array.
diff --git a/lib/objects/map.py b/lib/objects/map.py
index d3152f5..fb04c52 100644
--- a/lib/objects/map.py
+++ b/lib/objects/map.py
@@ -1,4 +1,7 @@
+import shelve
 from pathlib import Path
+from collections import UserDict
+
 import copy
 from math import sqrt
@@ -130,3 +133,30 @@ class Map(object):
         # https://matplotlib.org/api/pyplot_summary.html?highlight=colormaps
         img = ax.imshow(self.as_array, cmap='Greys_r')
         return dict(img=img, fig=fig, ax=ax)
+
+
+class MapStorage(object):
+
+    def __init__(self, map_root, load_all=False):
+        self.data = dict()
+        self.map_root = Path(map_root)
+        if load_all:
+            for map_file in self.map_root.glob('*.bmp'):
+                # use the stem so the '.bmp' suffix does not leak into the shelve name
+                _ = self[map_file.stem]
+
+    def __getitem__(self, item):
+        if hasattr(self, item):
+            return self.__getattribute__(item)
+        else:
+            # shelve.open expects a string path, cf. dataset.py
+            with shelve.open(str(self.map_root / f'{item}.pik'), flag='r') as d:
+                self.__setattr__(item, d['map']['map'])
+            return self[item]
diff --git a/lib/preprocessing/generator.py b/lib/preprocessing/generator.py
index f4ee676..dada23c 100644
--- a/lib/preprocessing/generator.py
+++ b/lib/preprocessing/generator.py
@@ -2,15 +2,10 @@ import multiprocessing as mp
 import pickle
 import shelve
 from collections import defaultdict
-from functools import partial
 from pathlib import Path
-from typing import Union
-
-from tqdm import trange
 
 from lib.objects.map import Map
-from lib.utils.parallel import run_n_in_parallel
 
 
 class Generator:
@@ -109,7 +104,7 @@ class Generator:
                                 trajectory=trajectory,
                                 labels=labels)
             if 'map' not in f:
-                f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
+                f['map'] = dict(map=self.map, name=self.map.name)
 
     @staticmethod
     def _remove_unequal(hom_dict):
diff --git a/main.py b/main.py
index 8564e68..eab41fb 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ import warnings
 from pytorch_lightning import Trainer
 from torch.utils.data import DataLoader
 
-from dataset.dataset import TrajData
+from dataset.dataset import TrajPairData
 
 from lib.utils.config import Config
 from lib.utils.logging import Logger
@@ -32,7 +32,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
-main_arg_parser.add_argument("--data_root", type=str, default='../data/rpoot', help="")
+main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
+main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="")
 
 # Transformations
 main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -65,7 +66,7 @@ config = Config.read_namespace(args)
 # TESTING ONLY #
 # =============================================================================
 hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=2500)
+dataset = TrajPairData('data', mapname='tate', alternatives=10000, trajectories=2500)
 dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                         batch_size=hparams.data_param.batchsize,
                         num_workers=hparams.data_param.worker)
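
On the open question in the subject line ("How to differentiate the loss function?"): one candidate direction, sketched below, is to sample the map's occupancy grid along a trajectory with bilinear interpolation via torch.nn.functional.grid_sample, which is differentiable with respect to the sampling coordinates, so the mean sampled occupancy acts as a soft obstacle penalty whose gradient flows back into the trajectory vertices. This is only a minimal sketch, not part of the patch: soft_obstacle_loss is a hypothetical helper, and it assumes the map is available as a float (H, W) array (cf. Map.as_array, 1.0 = obstacle) and a trajectory as an (N, 2) tensor of (x, y) pixel coordinates.

import torch
import torch.nn.functional as F


def soft_obstacle_loss(vertices: torch.Tensor, occupancy: torch.Tensor) -> torch.Tensor:
    """Mean obstacle occupancy sampled along trajectory vertices.

    vertices:  (N, 2) float tensor of (x, y) pixel coordinates (requires grad).
    occupancy: (H, W) float tensor, 1.0 = obstacle, 0.0 = free space.
    """
    h, w = occupancy.shape
    # Normalise pixel coordinates to [-1, 1], the range grid_sample expects.
    gx = vertices[:, 0] / (w - 1) * 2 - 1
    gy = vertices[:, 1] / (h - 1) * 2 - 1
    grid = torch.stack((gx, gy), dim=-1).view(1, 1, -1, 2)  # (1, 1, N, 2)
    occ = occupancy.view(1, 1, h, w)                        # (1, 1, H, W)
    # Bilinear sampling is differentiable w.r.t. the grid coordinates, so the
    # gradient of the penalty flows back into the trajectory vertices.
    samples = F.grid_sample(occ, grid, mode='bilinear', align_corners=True)
    return samples.mean()

Since only the vertices are sampled, a thin obstacle between two vertices can be missed; linearly interpolating extra points along each segment (itself a differentiable operation) before sampling would tighten the penalty.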