Model blocks, Model files, rearrange project structure

This commit is contained in:
Steffen Illium
2020-02-14 10:48:59 +01:00
parent 91ecf157d6
commit 1ce8d5993b
17 changed files with 192 additions and 109 deletions

@ -156,6 +156,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
self.apply(_weight_init)
class FilterLayer(nn.Module):
def __init__(self):
super(FilterLayer, self).__init__()
def forward(self, x):
tensor = x[:, -1]
return tensor
class MergingLayer(nn.Module):
def __init__(self):
super(MergingLayer, self).__init__()
def forward(self, x):
# ToDo: Which ones to combine?
return
#
# Sub - Modules
###################
@ -241,6 +261,32 @@ class DeConvModule(nn.Module):
return self.shape
class RecurrentModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
super(RecurrentModule, self).__init__()
self.use_bias = use_bias
self.num_layers = num_layers
self.in_shape = in_shape
self.hidden_size = hidden_size
self.dropout = dropout
self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
num_layers=num_layers,
bias=self.use_bias,
batch_first=True,
dropout=self.dropout)
def forward(self, x):
tensor = self.rnn(x)
return tensor
#
# Full Model Parts
###################

29
lib/models/cnn.py Normal file

@ -0,0 +1,29 @@
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
class CNNRouteGeneratorModel(LightningBaseModule):
@classmethod
def name(cls):
pass
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
pass
def test_step(self, *args, **kwargs):
pass
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params)
def forward(self, x):
pass

0
lib/models/full.py Normal file

0
lib/models/recurrent.py Normal file