Model blocks, Model files, rearrange project structure
This commit is contained in:
8
.idea/dictionaries/steffen.xml
generated
Normal file
8
.idea/dictionaries/steffen.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<component name="ProjectDictionaryState">
|
||||
<dictionary name="steffen">
|
||||
<words>
|
||||
<w>conv</w>
|
||||
<w>numlayers</w>
|
||||
</words>
|
||||
</dictionary>
|
||||
</component>
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
|
||||
</project>
|
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
from lib.objects.map import Map
|
||||
from preprocessing.generator import Generator
|
||||
from lib.preprocessing.generator import Generator
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_root = Path() / 'data'
|
||||
|
@@ -5,7 +5,7 @@ import torch
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
|
||||
from lib.objects.map import Map
|
||||
from preprocessing.generator import Generator
|
||||
from lib.preprocessing.generator import Generator
|
||||
|
||||
|
||||
class TrajDataset(Dataset):
|
||||
|
0
lib/evaluation/homotopic.py
Normal file
0
lib/evaluation/homotopic.py
Normal file
@@ -156,6 +156,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
self.apply(_weight_init)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(FilterLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
tensor = x[:, -1]
|
||||
return tensor
|
||||
|
||||
|
||||
class MergingLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MergingLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# ToDo: Which ones to combine?
|
||||
return
|
||||
|
||||
|
||||
#
|
||||
# Sub - Modules
|
||||
###################
|
||||
@@ -241,6 +261,32 @@ class DeConvModule(nn.Module):
|
||||
return self.shape
|
||||
|
||||
|
||||
class RecurrentModule(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
|
||||
def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
|
||||
super(RecurrentModule, self).__init__()
|
||||
self.use_bias = use_bias
|
||||
self.num_layers = num_layers
|
||||
self.in_shape = in_shape
|
||||
self.hidden_size = hidden_size
|
||||
self.dropout = dropout
|
||||
self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
|
||||
num_layers=num_layers,
|
||||
bias=self.use_bias,
|
||||
batch_first=True,
|
||||
dropout=self.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.rnn(x)
|
||||
return tensor
|
||||
|
||||
|
||||
#
|
||||
# Full Model Parts
|
||||
###################
|
||||
|
29
lib/models/cnn.py
Normal file
29
lib/models/cnn.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
pass
|
||||
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
pass
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
0
lib/models/full.py
Normal file
0
lib/models/full.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/bars.py
Normal file
0
lib/visualization/bars.py
Normal file
Reference in New Issue
Block a user