from functools import reduce
from operator import mul

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule

import matplotlib.pyplot as plt


class ConvHomDetector(LightningBaseModule):
    """Convolutional binary classifier over maps provided by ``TrajData``.

    Architecture: four ``ConvModule`` stages interleaved with three
    ``ResidualModule`` stages, then ``Flatten``, a hidden ``nn.Linear`` layer,
    a single-unit ``nn.Linear`` classifier and a sigmoid. Trained with
    ``nn.BCELoss`` on the sigmoid output.
    """

    name = 'CNNHomotopyClassifier'

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.lr``."""
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the BCE loss for one ``(inputs, labels)`` batch.

        Returns:
            dict with the loss under ``'loss'`` and, for logging, under
            ``'log'``.
        """
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        # The classifier emits one unit, so predictions carry a trailing dim of
        # size 1; give the labels the same trailing dim and a float dtype, as
        # required by BCELoss.
        loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
        return {'loss': loss, 'log': dict(loss=loss)}

    def test_step(self, batch_xy, batch_nb, **kwargs):
        """Run the model on one test batch and return predictions with labels."""
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)

    def test_epoch_end(self, outputs):
        """Aggregate all test-step outputs and compute the ROC AUC.

        Also logs the current matplotlib figure as an image under
        ``self.name``.

        Returns:
            dict with ``{'roc_auc': ...}`` under ``'log'``.
        """
        evaluation = ROCEvaluation(plot_roc=True)
        predictions = torch.cat([x['prediction'] for x in outputs])
        # Labels gain a column dim to line up with the (N, 1) predictions.
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)

        # The evaluator is called as eval(true_label, prediction), i.e. labels
        # first, scores second.
        roc_auc, _, _ = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())
        score_dict = dict(roc_auc=roc_auc)
        # NOTE(review): presumably the figure drawn by ROCEvaluation(plot_roc=True);
        # the figure is never closed here — confirm whether that leaks across epochs.
        self.logger.log_image(f'{self.name}', plt.gcf())
        return dict(log=score_dict)

    def __init__(self, hparams):
        """Build the dataset and the network layers.

        Args:
            hparams: hyper-parameter namespace; reads ``data_param.map_root``,
                ``model_param.filters`` and ``model_param.classes`` here, and
                ``lr`` in ``configure_optimizers``.

        Raises:
            AssertionError: if the dataset's maximum map shape is not 3-D.
        """
        super().__init__(hparams)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')

        # Additional Attributes
        self.map_shape = self.dataset.map_shapes_max

        # Model Parameters
        self.in_shape = self.dataset.map_shapes_max
        assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
        self.criterion = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()

        # NN Nodes
        # ============================
        # Convolutional Map Processing
        # Every conv/residual stage uses the same number of filters.
        filters = self.hparams.model_param.filters[0]

        # Each stage is built from the previous stage's output `.shape`, so the
        # order below must match the call order in `forward`.
        # The third positional argument to ResidualModule (3) is presumably the
        # number of inner ConvModules — TODO confirm against lib.modules.blocks.
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
                                        conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=filters)
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
                                        conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=filters)
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)
        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
                                        conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=filters)
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
                                     conv_padding=0, conv_filters=filters)

        self.flatten = Flatten(self.map_conv_3.shape)

        # ============================
        # Classifier
        # Hidden layer of width classes * 10 over the flattened feature map.
        self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)

        # Single output unit (binary, paired with sigmoid + BCELoss). A
        # multi-class variant would use `self.hparams.model_param.classes`
        # outputs instead.
        self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1)

    def forward(self, x):
        """Map a batch of maps to per-sample probabilities in ``[0, 1]``.

        Args:
            x: batch of maps; assumed ``(batch, *in_shape)`` — TODO confirm
                against ConvModule's expected layout.

        Returns:
            Tensor whose last dim has size 1 (sigmoid of the classifier output).
        """
        tensor = self.map_conv_0(x)
        tensor = self.map_res_1(tensor)
        tensor = self.map_conv_1(tensor)
        tensor = self.map_res_2(tensor)
        tensor = self.map_conv_2(tensor)
        # Previously skipped: map_res_3 was built (and map_conv_3 was sized from
        # its output shape) but never applied, leaving the stage untrained.
        tensor = self.map_res_3(tensor)
        tensor = self.map_conv_3(tensor)
        tensor = self.flatten(tensor)
        tensor = self.linear(tensor)
        tensor = self.classifier(tensor)
        tensor = self.sigmoid(tensor)
        return tensor