Model blocks, Model files, rearrange project structure
This commit is contained in:
8
.idea/dictionaries/steffen.xml
generated
Normal file
8
.idea/dictionaries/steffen.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<component name="ProjectDictionaryState">
|
||||||
|
<dictionary name="steffen">
|
||||||
|
<words>
|
||||||
|
<w>conv</w>
|
||||||
|
<w>numlayers</w>
|
||||||
|
</words>
|
||||||
|
</dictionary>
|
||||||
|
</component>
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="inheritedJdk" />
|
<orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -3,5 +3,5 @@
|
|||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@@ -1,11 +1,11 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from preprocessing.generator import Generator
|
from lib.preprocessing.generator import Generator
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data_root = Path() / 'data'
|
data_root = Path() / 'data'
|
||||||
maps_root = Path() / 'res' / 'maps'
|
maps_root = Path() / 'res' / 'maps'
|
||||||
map_object = Map('Tate').from_image(maps_root / 'tate_sw.bmp')
|
map_object = Map('Tate').from_image(maps_root / 'tate_sw.bmp')
|
||||||
generator = Generator(data_root, map_object)
|
generator = Generator(data_root, map_object)
|
||||||
generator.generate_n_trajectories_m_alternatives(100, 10, 'test')
|
generator.generate_n_trajectories_m_alternatives(100, 10, 'test')
|
||||||
|
@@ -1,96 +1,96 @@
|
|||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from preprocessing.generator import Generator
|
from lib.preprocessing.generator import Generator
|
||||||
|
|
||||||
|
|
||||||
class TrajDataset(Dataset):
|
class TrajDataset(Dataset):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
super(TrajDataset, self).__init__()
|
super(TrajDataset, self).__init__()
|
||||||
self.alternatives = data['alternatives']
|
self.alternatives = data['alternatives']
|
||||||
self.trajectory = data['trajectory']
|
self.trajectory = data['trajectory']
|
||||||
self.labels = data['labels']
|
self.labels = data['labels']
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.alternatives)
|
return len(self.alternatives)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]
|
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]
|
||||||
|
|
||||||
|
|
||||||
class DataSetMapping(Dataset):
|
class DataSetMapping(Dataset):
|
||||||
def __init__(self, dataset, mapping):
|
def __init__(self, dataset, mapping):
|
||||||
self._dataset = dataset
|
self._dataset = dataset
|
||||||
self._mapping = mapping
|
self._mapping = mapping
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._mapping.shape[0]
|
return self._mapping.shape[0]
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self._dataset[self._mapping[item]]
|
return self._dataset[self._mapping[item]]
|
||||||
|
|
||||||
|
|
||||||
class TrajData(object):
|
class TrajData(object):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
|
def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
|
||||||
train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
|
train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
|
||||||
|
|
||||||
self.rebuild = rebuild
|
self.rebuild = rebuild
|
||||||
self.equal_samples = equal_samples
|
self.equal_samples = equal_samples
|
||||||
self._alternatives = alternatives
|
self._alternatives = alternatives
|
||||||
self._trajectories = trajectories
|
self._trajectories = trajectories
|
||||||
self.mapname = mapname
|
self.mapname = mapname
|
||||||
self.train_split, self.val_split, self.test_split = train_val_test_split
|
self.train_split, self.val_split, self.test_split = train_val_test_split
|
||||||
self.data_root = Path(data_root)
|
self.data_root = Path(data_root)
|
||||||
self._dataset = None
|
self._dataset = None
|
||||||
self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()
|
self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()
|
||||||
|
|
||||||
def _build_data_on_demand(self):
|
def _build_data_on_demand(self):
|
||||||
maps_root = Path() / 'res' / 'maps'
|
maps_root = Path() / 'res' / 'maps'
|
||||||
map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
|
map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
|
||||||
assert maps_root.exists()
|
assert maps_root.exists()
|
||||||
dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
|
dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
|
||||||
if dataset_file.exists() and self.rebuild:
|
if dataset_file.exists() and self.rebuild:
|
||||||
dataset_file.unlink()
|
dataset_file.unlink()
|
||||||
if not dataset_file.exists():
|
if not dataset_file.exists():
|
||||||
generator = Generator(self.data_root, map_object)
|
generator = Generator(self.data_root, map_object)
|
||||||
generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
|
generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
|
||||||
self.mapname, equal_samples=self.equal_samples)
|
self.mapname, equal_samples=self.equal_samples)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _load_dataset(self):
|
def _load_dataset(self):
|
||||||
assert self._build_data_on_demand()
|
assert self._build_data_on_demand()
|
||||||
with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
|
with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
|
||||||
dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
|
dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
|
||||||
indices = torch.randperm(len(dataset))
|
indices = torch.randperm(len(dataset))
|
||||||
|
|
||||||
train_size = int(len(dataset) * self.train_split)
|
train_size = int(len(dataset) * self.train_split)
|
||||||
val_size = int(len(dataset) * self.val_split)
|
val_size = int(len(dataset) * self.val_split)
|
||||||
test_size = int(len(dataset) * self.test_split)
|
test_size = int(len(dataset) * self.test_split)
|
||||||
|
|
||||||
train_map = indices[:train_size]
|
train_map = indices[:train_size]
|
||||||
val_map = indices[train_size:val_size]
|
val_map = indices[train_size:val_size]
|
||||||
test_map = indices[test_size:]
|
test_map = indices[test_size:]
|
||||||
return dataset, train_map, val_map, test_map
|
return dataset, train_map, val_map, test_map
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_dataset(self):
|
def train_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._train_map)
|
return DataSetMapping(self._dataset, self._train_map)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def val_dataset(self):
|
def val_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._val_map)
|
return DataSetMapping(self._dataset, self._val_map)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def test_dataset(self):
|
def test_dataset(self):
|
||||||
return DataSetMapping(self._dataset, self._test_map)
|
return DataSetMapping(self._dataset, self._test_map)
|
||||||
|
|
||||||
def get_datasets(self):
|
def get_datasets(self):
|
||||||
return self.train_dataset, self.val_dataset, self.test_dataset
|
return self.train_dataset, self.val_dataset, self.test_dataset
|
||||||
|
0
lib/evaluation/homotopic.py
Normal file
0
lib/evaluation/homotopic.py
Normal file
@@ -156,6 +156,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
self.apply(_weight_init)
|
self.apply(_weight_init)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(FilterLayer, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
tensor = x[:, -1]
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MergingLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MergingLayer, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# ToDo: Which ones to combine?
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Sub - Modules
|
# Sub - Modules
|
||||||
###################
|
###################
|
||||||
@@ -241,6 +261,32 @@ class DeConvModule(nn.Module):
|
|||||||
return self.shape
|
return self.shape
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrentModule(nn.Module):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||||
|
output = self(x)
|
||||||
|
return output.shape[1:]
|
||||||
|
|
||||||
|
def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
|
||||||
|
super(RecurrentModule, self).__init__()
|
||||||
|
self.use_bias = use_bias
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.in_shape = in_shape
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.dropout = dropout
|
||||||
|
self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
|
||||||
|
num_layers=num_layers,
|
||||||
|
bias=self.use_bias,
|
||||||
|
batch_first=True,
|
||||||
|
dropout=self.dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
tensor = self.rnn(x)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Full Model Parts
|
# Full Model Parts
|
||||||
###################
|
###################
|
||||||
|
29
lib/models/cnn.py
Normal file
29
lib/models/cnn.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||||
|
|
||||||
|
|
||||||
|
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validation_step(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validation_end(self, outputs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_step(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, *params):
|
||||||
|
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
0
lib/models/full.py
Normal file
0
lib/models/full.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/bars.py
Normal file
0
lib/visualization/bars.py
Normal file
Reference in New Issue
Block a user