2020-03-03 15:10:17 +01:00

85 lines
3.7 KiB
Python

from functools import reduce
from operator import mul
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
class ConvHomDetector(LightningBaseModule):
    """CNN binary classifier made of alternating conv and residual stages.

    The network processes a 3-dimensional map/image input through four
    ``ConvModule`` stages interleaved with three ``ResidualModule`` stages,
    flattens the result and maps it through two linear layers to a single
    sigmoid-activated output.

    Hyperparameters read from ``self.hparams`` (presumably populated by
    ``LightningBaseModule.__init__`` -- verify against the base class):
    ``lr``, ``data_param.root``, ``model_param.filters`` and
    ``model_param.classes``.
    """

    name = 'CNNHomotopyClassifier'

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.lr``."""
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Run one training step and return the binary cross-entropy loss.

        Args:
            batch_xy: Pair ``(batch_x, batch_y)`` of inputs and labels.
            batch_nb: Batch index; not used here.

        Returns:
            Dict with the loss under ``'loss'`` and again under ``'log'``.
        """
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        # NOTE(review): binary_cross_entropy expects prediction and target of
        # the same shape; the final Linear emits a trailing dim of size 1 --
        # confirm the labels delivered by the dataset match that.
        loss = F.binary_cross_entropy(pred_y, batch_y.float())
        return {'loss': loss, 'log': dict(loss=loss)}

    def __init__(self, *params):
        """Build the dataset handle and all network layers.

        Args:
            *params: Forwarded unchanged to ``LightningBaseModule.__init__``.

        Raises:
            AssertionError: If the dataset's maximum map shape does not have
                exactly three dimensions.
        """
        super(ConvHomDetector, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.root)

        # Additional Attributes
        self.map_shape = self.dataset.map_shapes_max

        # Model Parameters
        self.in_shape = self.dataset.map_shapes_max
        assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'

        # NN Nodes
        # ============================
        # Convolutional Map Processing
        #
        # Every stage reads only the first entry of ``model_param.filters``.
        n_filters = self.hparams.model_param.filters[0]
        # Settings shared by all three residual stages.
        res_kwargs = dict(conv_kernel=3, conv_stride=1, conv_padding=1, conv_filters=n_filters)

        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
                                     conv_padding=0, conv_filters=n_filters)
        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=n_filters)
        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=n_filters)
        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=n_filters)

        self.flatten = Flatten(self.map_conv_3.shape)

        # ============================
        # Classifier
        #
        self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
        # Single sigmoid unit for binary labels. A multi-class variant would
        # instead emit ``model_param.classes`` outputs followed by a softmax.
        self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1)
        self.out_activation = nn.Sigmoid()

    def forward(self, x):
        """Map an input batch to sigmoid-activated scores.

        Args:
            x: Input batch; assumed to match ``self.in_shape`` per sample --
                TODO confirm against ``ConvModule``.

        Returns:
            Output of the final ``nn.Sigmoid``, with a trailing dim of size 1.
        """
        tensor = self.map_conv_0(x)
        tensor = self.map_res_1(tensor)
        tensor = self.map_conv_1(tensor)
        tensor = self.map_res_2(tensor)
        tensor = self.map_conv_2(tensor)
        # The third residual stage is wired between map_conv_2 and map_conv_3
        # in __init__; it must run here too, otherwise its parameters are
        # never trained.
        tensor = self.map_res_3(tensor)
        tensor = self.map_conv_3(tensor)
        tensor = self.flatten(tensor)
        tensor = self.linear(tensor)
        tensor = self.classifier(tensor)
        tensor = self.out_activation(tensor)
        return tensor