119 lines
5.1 KiB
Python
119 lines
5.1 KiB
Python
from functools import reduce
|
|
from operator import mul
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
|
|
from datasets.trajectory_dataset import TrajData
|
|
from ml_lib.evaluation.classification import ROCEvaluation
|
|
from ml_lib.modules.utils import LightningBaseModule, Flatten
|
|
from ml_lib.modules.blocks import ConvModule, ResidualModule
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class ConvHomDetector(LightningBaseModule):
    """CNN that maps a trajectory map to a single sigmoid score.

    The network alternates ``ConvModule`` and ``ResidualModule`` stages,
    flattens the result and applies two linear layers. The last layer has a
    single output unit followed by a sigmoid, and training uses ``nn.BCELoss``.

    Hyper-parameters read from ``hparams``:
        lr: learning rate for ``Adam``.
        data_param.map_root: root passed to ``TrajData``.
        model_param.filters: only ``filters[0]`` is used, for every conv stage.
        model_param.classes: scales the hidden linear width (``classes * 10``).
    """

    name = 'CNNHomotopyClassifier'

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.lr``."""
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the BCE loss for one ``(inputs, labels)`` training batch."""
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        # The classifier emits one unit per sample, so the labels get a trailing
        # unit dimension and a float dtype to match the BCELoss input.
        loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
        return {'loss': loss, 'log': dict(loss=loss)}

    def _collect_predictions(self, batch_xy, batch_nb):
        """Run the model on one batch; keep predictions and labels for epoch end."""
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)

    def test_step(self, batch_xy, batch_nb, **kwargs):
        """Collect predictions and labels for one test batch."""
        return self._collect_predictions(batch_xy, batch_nb)

    def validation_step(self, batch_xy, batch_nb, **kwargs):
        """Collect predictions and labels for one validation batch."""
        return self._collect_predictions(batch_xy, batch_nb)

    def test_epoch_end(self, outputs):
        """Score all test batches and log the ROC figure."""
        return self._val_test_end(outputs, test=True)

    def validation_epoch_end(self, outputs: list):
        """Score all validation batches without plotting or logging an image."""
        # The ROC plot and the image upload are only wanted for the test run;
        # validation runs every epoch.
        return self._val_test_end(outputs, test=False)

    def _val_test_end(self, outputs, test=True):
        """Compute ROC AUC over the collected step outputs.

        Args:
            outputs: list of dicts produced by ``test_step``/``validation_step``.
            test: when truthy, plot the ROC curve and log the current figure.

        Returns:
            ``dict(score=roc_auc, log=dict(roc_auc=roc_auc))``.
        """
        evaluation = ROCEvaluation(plot_roc=bool(test))
        predictions = torch.cat([x['prediction'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)

        # ROCEvaluation is called as evaluation(true_label, prediction).
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())

        if test:
            # Logs whatever matplotlib figure is current; presumably the one
            # ROCEvaluation drew into when plot_roc=True -- TODO confirm.
            self.logger.log_image(f'{self.name}', plt.gcf())

        return dict(score=roc_auc, log=dict(roc_auc=roc_auc))

    def __init__(self, hparams):
        super().__init__(hparams)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='classifier_all_in_map')

        # Additional attributes
        self.map_shape = self.dataset.map_shapes_max

        # Model parameters
        self.in_shape = self.dataset.map_shapes_max
        assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
        self.criterion = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

        # NN nodes
        # ============================
        # Convolutional map processing. Every stage uses the same filter count.
        # The residual blocks use kernel 3 / padding 1 (stride 1). The plain
        # conv stages use padding 0.
        filters = self.hparams.model_param.filters[0]
        res_kwargs = dict(conv_kernel=3, conv_stride=1, conv_padding=1, conv_filters=filters)

        # Construction order is kept so parameter order and state_dict keys
        # are unchanged.
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3, **res_kwargs)
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)

        self.flatten = Flatten(self.map_conv_3.shape)

        # ============================
        # Classifier
        hidden_width = self.hparams.model_param.classes * 10
        self.linear = nn.Linear(reduce(mul, self.flatten.shape), hidden_width)
        # A single output unit (binary score) instead of model_param.classes
        # units, which is what the sigmoid + BCELoss pairing expects.
        self.classifier = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return one sigmoid score per sample for the map batch ``x``.

        ``x`` presumably has the layout ``(batch, *in_shape)`` with
        ``len(in_shape) == 3`` -- TODO confirm against ConvModule.
        """
        tensor = self.map_conv_0(x)
        tensor = self.map_res_1(tensor)
        tensor = self.map_conv_1(tensor)
        tensor = self.map_res_2(tensor)
        tensor = self.map_conv_2(tensor)
        # map_conv_3 is sized from map_res_3.shape, so the third residual
        # stage has to run here (like the first two).
        tensor = self.map_res_3(tensor)
        tensor = self.map_conv_3(tensor)
        tensor = self.flatten(tensor)
        tensor = self.linear(tensor)
        tensor = self.relu(tensor)
        tensor = self.classifier(tensor)
        tensor = self.sigmoid(tensor)
        return tensor
|