New Dataset Generator, How to differentiate the loss function?
This commit is contained in:
parent
61c5cb44a0
commit
8424251ca0
71
.gitignore
vendored
Normal file
71
.gitignore
vendored
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
|
||||||
|
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||||
|
|
||||||
|
# User-specific stuff
|
||||||
|
.idea/**/workspace.xml
|
||||||
|
.idea/**/tasks.xml
|
||||||
|
.idea/**/usage.statistics.xml
|
||||||
|
.idea/**/dictionaries
|
||||||
|
.idea/**/shelf
|
||||||
|
|
||||||
|
# Generated files
|
||||||
|
.idea/**/contentModel.xml
|
||||||
|
|
||||||
|
# Sensitive or high-churn files
|
||||||
|
.idea/**/dataSources/
|
||||||
|
.idea/**/dataSources.ids
|
||||||
|
.idea/**/dataSources.local.xml
|
||||||
|
.idea/**/sqlDataSources.xml
|
||||||
|
.idea/**/dynamic.xml
|
||||||
|
.idea/**/uiDesigner.xml
|
||||||
|
.idea/**/dbnavigator.xml
|
||||||
|
|
||||||
|
# Gradle
|
||||||
|
.idea/**/gradle.xml
|
||||||
|
.idea/**/libraries
|
||||||
|
|
||||||
|
# Gradle and Maven with auto-import
|
||||||
|
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||||
|
# since they will be recreated, and may cause churn. Uncomment if using
|
||||||
|
# auto-import.
|
||||||
|
# .idea/artifacts
|
||||||
|
# .idea/compiler.xml
|
||||||
|
# .idea/jarRepositories.xml
|
||||||
|
# .idea/modules.xml
|
||||||
|
# .idea/*.iml
|
||||||
|
# .idea/modules
|
||||||
|
# *.iml
|
||||||
|
# *.ipr
|
||||||
|
|
||||||
|
# CMake
|
||||||
|
cmake-build-*/
|
||||||
|
|
||||||
|
# Mongo Explorer plugin
|
||||||
|
.idea/**/mongoSettings.xml
|
||||||
|
|
||||||
|
# File-based project format
|
||||||
|
*.iws
|
||||||
|
|
||||||
|
# IntelliJ
|
||||||
|
out/
|
||||||
|
|
||||||
|
# mpeltonen/sbt-idea plugin
|
||||||
|
.idea_modules/
|
||||||
|
|
||||||
|
# JIRA plugin
|
||||||
|
atlassian-ide-plugin.xml
|
||||||
|
|
||||||
|
# Cursive Clojure plugin
|
||||||
|
.idea/replstate.xml
|
||||||
|
|
||||||
|
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||||
|
com_crashlytics_export_strings.xml
|
||||||
|
crashlytics.properties
|
||||||
|
crashlytics-build.properties
|
||||||
|
fabric.properties
|
||||||
|
|
||||||
|
# Editor-based Rest Client
|
||||||
|
.idea/httpRequests
|
||||||
|
|
||||||
|
# Android studio 3.1+ serialized cache file
|
||||||
|
.idea/caches/build_file_checksums.ser
|
13
.idea/deployment.xml
generated
13
.idea/deployment.xml
generated
@ -1,18 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
|
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="ErLoWa-AiMachine">
|
<paths name="steffen@aimachine:22">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
<mappings>
|
<mappings>
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
<mapping deploy="\" local="$PROJECT_DIR$" />
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="traj_gen-AiMachine">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/hom_traj_gen" local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
</mappings>
|
||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
|
2
.idea/dictionaries/steffen.xml
generated
2
.idea/dictionaries/steffen.xml
generated
@ -3,8 +3,10 @@
|
|||||||
<words>
|
<words>
|
||||||
<w>conv</w>
|
<w>conv</w>
|
||||||
<w>homotopic</w>
|
<w>homotopic</w>
|
||||||
|
<w>hparams</w>
|
||||||
<w>hyperparamter</w>
|
<w>hyperparamter</w>
|
||||||
<w>numlayers</w>
|
<w>numlayers</w>
|
||||||
|
<w>traj</w>
|
||||||
</words>
|
</words>
|
||||||
</dictionary>
|
</dictionary>
|
||||||
</component>
|
</component>
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
|||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@ -1,28 +1,31 @@
|
|||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from random import choice
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from lib.preprocessing.generator import Generator
|
from lib.preprocessing.generator import Generator
|
||||||
|
|
||||||
|
|
||||||
class TrajDataset(Dataset):
|
class TrajPairDataset(Dataset):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
super(TrajDataset, self).__init__()
|
super(TrajPairDataset, self).__init__()
|
||||||
self.alternatives = data['alternatives']
|
self.alternatives = data['alternatives']
|
||||||
self.trajectory = data['trajectory']
|
self.trajectory = data['trajectory']
|
||||||
self.labels = data['labels']
|
self.labels = data['labels']
|
||||||
|
self.mapname = data['map']['name'][4:] if data['map']['name'].startswith('map_') else data['map']['name']
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.alternatives)
|
return len(self.alternatives)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]
|
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname
|
||||||
|
|
||||||
|
|
||||||
class DataSetMapping(Dataset):
|
class DatasetMapping(Dataset):
|
||||||
def __init__(self, dataset, mapping):
|
def __init__(self, dataset, mapping):
|
||||||
self._dataset = dataset
|
self._dataset = dataset
|
||||||
self._mapping = mapping
|
self._mapping = mapping
|
||||||
@ -34,12 +37,12 @@ class DataSetMapping(Dataset):
|
|||||||
return self._dataset[self._mapping[item]]
|
return self._dataset[self._mapping[item]]
|
||||||
|
|
||||||
|
|
||||||
class TrajData(object):
|
class TrajPairData(object):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
|
def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000, alternatives=10,
|
||||||
train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
|
train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
|
||||||
|
|
||||||
self.rebuild = rebuild
|
self.rebuild = rebuild
|
||||||
@ -49,13 +52,13 @@ class TrajData(object):
|
|||||||
self.mapname = mapname
|
self.mapname = mapname
|
||||||
self.train_split, self.val_split, self.test_split = train_val_test_split
|
self.train_split, self.val_split, self.test_split = train_val_test_split
|
||||||
self.data_root = Path(data_root)
|
self.data_root = Path(data_root)
|
||||||
|
self.maps_root = Path(data_root) if data_root else Path() / 'res' / 'maps'
|
||||||
self._dataset = None
|
self._dataset = None
|
||||||
self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()
|
self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()
|
||||||
|
|
||||||
def _build_data_on_demand(self):
|
def _build_data_on_demand(self):
|
||||||
maps_root = Path() / 'res' / 'maps'
|
map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
|
||||||
map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
|
assert self.maps_root.exists()
|
||||||
assert maps_root.exists()
|
|
||||||
dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
|
dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
|
||||||
if dataset_file.exists() and self.rebuild:
|
if dataset_file.exists() and self.rebuild:
|
||||||
dataset_file.unlink()
|
dataset_file.unlink()
|
||||||
@ -68,7 +71,7 @@ class TrajData(object):
|
|||||||
def _load_dataset(self):
|
def _load_dataset(self):
|
||||||
assert self._build_data_on_demand()
|
assert self._build_data_on_demand()
|
||||||
with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
|
with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
|
||||||
dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
|
dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
|
||||||
indices = torch.randperm(len(dataset))
|
indices = torch.randperm(len(dataset))
|
||||||
|
|
||||||
train_size = int(len(dataset) * self.train_split)
|
train_size = int(len(dataset) * self.train_split)
|
||||||
@ -82,15 +85,50 @@ class TrajData(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def train_dataset(self):
|
def train_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._train_map)
|
return DatasetMapping(self._dataset, self._train_map)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def val_dataset(self):
|
def val_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._val_map)
|
return DatasetMapping(self._dataset, self._val_map)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def test_dataset(self):
|
def test_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._test_map)
|
return DatasetMapping(self._dataset, self._test_map)
|
||||||
|
|
||||||
def get_datasets(self):
|
def get_datasets(self):
|
||||||
return self.train_dataset, self.val_dataset, self.test_dataset
|
return self.train_dataset, self.val_dataset, self.test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class TrajDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, data_root, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100.000, **_):
|
||||||
|
super(TrajDataset, self).__init__()
|
||||||
|
self.mapname = mapname
|
||||||
|
self.maps_root = maps_root
|
||||||
|
self.data_root = data_root
|
||||||
|
self._len = length
|
||||||
|
|
||||||
|
self._map_obj = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._len
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
trajectory = self._map_obj.get_random_trajectory()
|
||||||
|
label = choice([0, 1])
|
||||||
|
return trajectory.vertices, None, label, self.mapname
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_dataset(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_dataset(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_dataset(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_datasets(self):
|
||||||
|
return self, self, self
|
||||||
|
@ -12,7 +12,8 @@ import pytorch_lightning as pl
|
|||||||
###################
|
###################
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from dataset.dataset import TrajData
|
from dataset.dataset import TrajDataset
|
||||||
|
from lib.objects.map import MapStorage
|
||||||
|
|
||||||
|
|
||||||
class Flatten(nn.Module):
|
class Flatten(nn.Module):
|
||||||
@ -77,7 +78,8 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
# Data loading
|
# Data loading
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajData('data')
|
self.dataset = TrajDataset('data')
|
||||||
|
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.shape
|
return self.shape
|
||||||
@ -176,6 +178,17 @@ class MergingLayer(nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class FlipTensor(nn.Module):
|
||||||
|
def __init__(self, dim=-2):
|
||||||
|
super(FlipTensor, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
|
||||||
|
idx = torch.as_tensor(idx).long()
|
||||||
|
inverted_tensor = x.index_select(self.dim, idx)
|
||||||
|
return inverted_tensor
|
||||||
|
|
||||||
#
|
#
|
||||||
# Sub - Modules
|
# Sub - Modules
|
||||||
###################
|
###################
|
||||||
|
@ -3,9 +3,7 @@ from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generat
|
|||||||
|
|
||||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||||
|
|
||||||
@classmethod
|
name = 'CNNRouteGenerator'
|
||||||
def name(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
pass
|
pass
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||||
|
from lib.models.losses import BinaryHomotopicLoss
|
||||||
|
from lib.objects.map import Map
|
||||||
|
from lib.objects.trajectory import Trajectory
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
nn.MSELoss
|
||||||
|
|
||||||
|
class LinearRouteGeneratorModel(LightningBaseModule):
|
||||||
|
|
||||||
|
name = 'LinearRouteGenerator'
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validation_step(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validation_end(self, outputs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_nb, *args, **kwargs):
|
||||||
|
# Type Annotation
|
||||||
|
traj_x: Trajectory
|
||||||
|
traj_o: Trajectory
|
||||||
|
label_x: int
|
||||||
|
map_name: str
|
||||||
|
map_x: Map
|
||||||
|
# Batch unpacking
|
||||||
|
traj_x, traj_o, label_x, map_name = batch
|
||||||
|
map_x = self.map_storage[map_name]
|
||||||
|
pred_y = self(map_x, traj_x, label_x)
|
||||||
|
|
||||||
|
loss = self.loss(traj_x, pred_y)
|
||||||
|
return dict(loss=loss, log=dict(loss=loss))
|
||||||
|
|
||||||
|
def test_step(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, *params):
|
||||||
|
super(LinearRouteGeneratorModel, self).__init__(*params)
|
||||||
|
|
||||||
|
self.loss = BinaryHomotopicLoss(self.map_storage)
|
||||||
|
|
||||||
|
def forward(self, map_x, traj_x, label_x):
|
||||||
|
pass
|
21
lib/models/losses.py
Normal file
21
lib/models/losses.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
from lib.models.blocks import FlipTensor
|
||||||
|
from lib.objects.map import MapStorage
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryHomotopicLoss(nn.Module):
|
||||||
|
def __init__(self, map_storage: MapStorage):
|
||||||
|
super(BinaryHomotopicLoss, self).__init__()
|
||||||
|
self.map_storage = map_storage
|
||||||
|
self.flipper = FlipTensor()
|
||||||
|
|
||||||
|
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
|
||||||
|
y_flipepd = self.flipper(y)
|
||||||
|
circle = torch.cat((x, y_flipepd), dim=-1)
|
||||||
|
masp = self.map_storage[mapname].are
|
||||||
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
|||||||
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from collections import UserDict
|
||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
@ -130,3 +133,30 @@ class Map(object):
|
|||||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||||
return dict(img=img, fig=fig, ax=ax)
|
return dict(img=img, fig=fig, ax=ax)
|
||||||
|
|
||||||
|
|
||||||
|
class MapStorage(object):
|
||||||
|
|
||||||
|
def __init__(self, map_root, load_all=False):
|
||||||
|
self.data = dict()
|
||||||
|
self.map_root = Path(map_root)
|
||||||
|
if load_all:
|
||||||
|
for map_file in self.map_root.glob('*.bmp'):
|
||||||
|
_ = self[map_file.name]
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
if item in hasattr(self, item):
|
||||||
|
return self.__getattribute__(item)
|
||||||
|
else:
|
||||||
|
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
|
||||||
|
self.__setattr__(item, d['map']['map'])
|
||||||
|
return self[item]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,15 +2,10 @@ import multiprocessing as mp
|
|||||||
import pickle
|
import pickle
|
||||||
import shelve
|
import shelve
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from lib.utils.parallel import run_n_in_parallel
|
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
@ -109,7 +104,7 @@ class Generator:
|
|||||||
trajectory=trajectory,
|
trajectory=trajectory,
|
||||||
labels=labels)
|
labels=labels)
|
||||||
if 'map' not in f:
|
if 'map' not in f:
|
||||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
f['map'] = dict(map=self.map, name=self.map.name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _remove_unequal(hom_dict):
|
def _remove_unequal(hom_dict):
|
||||||
|
7
main.py
7
main.py
@ -10,7 +10,7 @@ import warnings
|
|||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from dataset.dataset import TrajData
|
from dataset.dataset import TrajPairData
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from lib.utils.logging import Logger
|
||||||
|
|
||||||
@ -32,7 +32,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
|||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='../data/rpoot', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
||||||
|
main_arg_parser.add_argument("--map_root", type=str, default='/res/maps', help="")
|
||||||
|
|
||||||
# Transformations
|
# Transformations
|
||||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||||
@ -65,7 +66,7 @@ config = Config.read_namespace(args)
|
|||||||
# TESTING ONLY #
|
# TESTING ONLY #
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
hparams = config.model_paramters
|
hparams = config.model_paramters
|
||||||
dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=2500)
|
dataset = TrajPairData('data', mapname='tate', alternatives=10000, trajectories=2500)
|
||||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||||
batch_size=hparams.data_param.batchsize,
|
batch_size=hparams.data_param.batchsize,
|
||||||
num_workers=hparams.data_param.worker)
|
num_workers=hparams.data_param.worker)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user