Model blocks, Model files, rearrange project structure

Steffen Illium
2020-02-14 10:48:59 +01:00
parent 91ecf157d6
commit 1ce8d5993b
17 changed files with 192 additions and 109 deletions

.idea/dictionaries/steffen.xml (generated, new file, +8)

@@ -0,0 +1,8 @@
<component name="ProjectDictionaryState">
<dictionary name="steffen">
<words>
<w>conv</w>
<w>numlayers</w>
</words>
</dictionary>
</component>


@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
-   <orderEntry type="inheritedJdk" />
+   <orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>

.idea/misc.xml (generated, 2 lines changed)

@@ -3,5 +3,5 @@
<component name="JavaScriptSettings"> <component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" /> <option name="languageLevel" value="ES6" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
</project> </project>


@@ -1,11 +1,11 @@
from pathlib import Path
from lib.objects.map import Map
-from preprocessing.generator import Generator
+from lib.preprocessing.generator import Generator


if __name__ == '__main__':
    data_root = Path() / 'data'
    maps_root = Path() / 'res' / 'maps'
    map_object = Map('Tate').from_image(maps_root / 'tate_sw.bmp')
    generator = Generator(data_root, map_object)
    generator.generate_n_trajectories_m_alternatives(100, 10, 'test')


@@ -1,96 +1,96 @@
import shelve
from pathlib import Path

import torch
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map
-from preprocessing.generator import Generator
+from lib.preprocessing.generator import Generator


class TrajDataset(Dataset):

    def __init__(self, data):
        super(TrajDataset, self).__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']

    def __len__(self):
        return len(self.alternatives)

    def __getitem__(self, item):
        return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]


class DataSetMapping(Dataset):

    def __init__(self, dataset, mapping):
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        return self._mapping.shape[0]

    def __getitem__(self, item):
        return self._dataset[self._mapping[item]]


class TrajData(object):

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self._dataset = None
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        maps_root = Path() / 'res' / 'maps'
        map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
        assert maps_root.exists()
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
        indices = torch.randperm(len(dataset))
        train_size = int(len(dataset) * self.train_split)
        val_size = int(len(dataset) * self.val_split)
        # Consecutive, non-overlapping slices of the shuffled indices;
        # the test split takes whatever remains after train and val.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return dataset, train_map, val_map, test_map
    @property
    def train_dataset(self):
        return DataSetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DataSetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DataSetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
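For reference, a minimal usage sketch of TrajData. The diff omits this file's name, so the sample counts, batch size, and DataLoader settings below are illustrative, not taken from the repository:

from torch.utils.data import DataLoader

traj_data = TrajData('data', mapname='tate_sw', trajectories=100, alternatives=10)
train_set, val_set, test_set = traj_data.get_datasets()

# Each sample is (trajectory vertices, alternative vertices, label).
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)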


@@ -156,6 +156,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
        self.apply(_weight_init)


class FilterLayer(nn.Module):

    def __init__(self):
        super(FilterLayer, self).__init__()

    def forward(self, x):
        # Keep only the last time step of a (batch, seq, features) tensor.
        tensor = x[:, -1]
        return tensor


class MergingLayer(nn.Module):

    def __init__(self):
        super(MergingLayer, self).__init__()

    def forward(self, x):
        # ToDo: Which ones to combine?
        return


#
# Sub - Modules
###################
@@ -241,6 +261,32 @@ class DeConvModule(nn.Module):
        return self.shape


class RecurrentModule(nn.Module):

    @property
    def shape(self):
        # Probe with a dummy batch; flatten the trailing feature dims
        # into the (batch, seq, features) layout the RNN expects.
        x = torch.randn(self.in_shape).unsqueeze(0).flatten(start_dim=2)
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
        super(RecurrentModule, self).__init__()
        self.use_bias = use_bias
        self.num_layers = num_layers
        self.in_shape = in_shape
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
                             num_layers=num_layers,
                             bias=self.use_bias,
                             batch_first=True,
                             dropout=self.dropout)

    def forward(self, x):
        # cell_type modules (nn.GRU/nn.LSTM) return (output, hidden);
        # pass on only the output sequence so callers get a tensor.
        tensor, _ = self.rnn(x)
        return tensor


#
# Full Model Parts
###################
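A short shape check chaining the new RecurrentModule into FilterLayer. The import path follows the one used in lib/models/cnn.py below; all sizes are made up for illustration:

import torch
from lib.models.blocks import RecurrentModule, FilterLayer

# Hypothetical input: 10 time steps of 4x2 = 8 features each.
rnn_block = RecurrentModule(in_shape=(10, 4, 2), hidden_size=16)
last_step = FilterLayer()

x = torch.randn(1, 10, 8)   # (batch, seq, features), pre-flattened
out = rnn_block(x)          # (1, 10, 16): full GRU output sequence
summary = last_step(out)    # (1, 16): last time step only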

lib/models/cnn.py (new file, +29)

@@ -0,0 +1,29 @@
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule


class CNNRouteGeneratorModel(LightningBaseModule):

    @classmethod
    def name(cls):
        pass

    def configure_optimizers(self):
        pass

    def validation_step(self, *args, **kwargs):
        pass

    def validation_end(self, outputs):
        pass

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        pass

    def test_step(self, *args, **kwargs):
        pass

    def __init__(self, *params):
        super(CNNRouteGeneratorModel, self).__init__(*params)

    def forward(self, x):
        pass
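The class is still a skeleton. As a sketch of how two of the hooks might later be filled in — the optimizer, learning rate, loss, and batch layout are assumptions, not taken from the repository; the dict-style return matches the pytorch-lightning API of that era:

import torch
import torch.nn.functional as F

def configure_optimizers(self):
    # Placeholder optimizer and learning rate.
    return torch.optim.Adam(self.parameters(), lr=1e-3)

def training_step(self, batch_xy, batch_nb, *args, **kwargs):
    batch_x, batch_y = batch_xy          # assumed batch layout
    prediction = self(batch_x)
    loss = F.binary_cross_entropy(prediction, batch_y)
    return {'loss': loss}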

lib/models/full.py (new, empty file)

lib/models/recurrent.py (new, empty file)
