from abc import ABC
from pathlib import Path

import torch
from torch import nn
# NOTE(review): this binds ``torch.functional`` (einsum, norm, split, ...), *not*
# ``torch.nn.functional``. Layer ops such as ``pad``/``interpolate`` live in the
# latter, so the modules below call ``nn.functional`` explicitly.
from torch import functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from lib.objects.map import MapStorage


# Utility - Modules
###################


def _probe_output_shape(module):
    """Return the per-sample output shape of ``module`` for an input of ``module.in_shape``.

    A single random sample (batch of one) is pushed through ``module`` with
    autograd disabled and every submodule temporarily switched to eval mode, so
    BatchNorm neither rejects the size-one batch nor updates its running
    statistics. Each submodule's original train/eval flag is restored afterwards.
    The probe input is moved to the device of the module's first parameter, if
    it has any.

    Best-effort: on any failure the exception is printed and ``-1`` is returned.
    """
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    try:
        x = torch.randn(module.in_shape).unsqueeze(0)
        first_param = next(module.parameters(), None)
        if first_param is not None:
            x = x.to(first_param.device)
        module.eval()
        with torch.no_grad():
            output = module(x)
        return output.shape[1:]
    except Exception as e:
        print(e)
        return -1
    finally:
        # Restore per-submodule flags rather than calling module.train(), which
        # would overwrite deliberately frozen (eval-mode) children.
        for sub, mode in saved_modes:
            sub.training = mode


class Flatten(nn.Module):
    """Reshape each sample of a batch to ``to`` while keeping the batch dimension."""

    @property
    def shape(self):
        """Per-sample output shape for an input of ``in_shape``, or -1 if probing fails."""
        return _probe_output_shape(self)

    def __init__(self, in_shape, to=(-1, )):
        """
        :param in_shape: per-sample input shape (without batch dim), used by ``shape``.
        :param to: target per-sample shape passed to ``Tensor.view``; default flattens.
        """
        super(Flatten, self).__init__()
        self.in_shape = in_shape
        self.to = to

    def forward(self, x):
        return x.view(x.size(0), *self.to)


class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with arguments fixed at construction."""

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
                        mode=self.mode, align_corners=self.align_corners)
        return x


class AutoPad(nn.Module):
    """Zero-pad the last two dimensions up to the next multiple of ``base ** interpolations``."""

    def __init__(self, interpolations=3, base=2):
        super(AutoPad, self).__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        # (-n) % fct is the smallest non-negative amount that makes n a multiple of fct
        # (0 when n already is one).
        pad_last = (-x.shape[-1]) % self.fct
        pad_second_last = (-x.shape[-2]) % self.fct
        # nn.functional.pad reads the list as (last_left, last_right, 2nd_last_left, 2nd_last_right):
        # the last dim is padded at its end, the second-to-last at its *start*.
        # NOTE(review): [0, pad_last, 0, pad_second_last] would pad the second-to-last dim at
        # its end instead — confirm which side is intended.
        return nn.functional.pad(x, [0, pad_last, pad_second_last, 0])


class LightningBaseModule(pl.LightningModule, ABC):
    """Shared base for the project's Lightning models.

    Provides output-shape probing, class persistence, dataloaders built from
    ``self.dataset`` and ``hparams.data_param``, ROC-AUC evaluation in
    ``test_end``, and weight initialisation. Subclasses must implement the
    abstract hooks (``name``, ``forward``, the step/end hooks, ``configure_optimizers``),
    set ``self.in_shape`` for ``shape`` to work and — presumably — assign
    ``self.dataset`` (with ``train_dataset``/``val_dataset``/``test_dataset``)
    before the dataloaders are requested.
    """

    @classmethod
    def name(cls):
        raise NotImplementedError('Give your model a name!')

    @property
    def shape(self):
        """Per-sample output shape for an input of ``self.in_shape``, or -1 if probing fails."""
        return _probe_output_shape(self)

    def __init__(self, params):
        """
        :param params: hyper-parameter namespace; must provide ``data_param.map_root``,
            ``data_param.batchsize`` and ``data_param.worker``.
        """
        super(LightningBaseModule, self).__init__()
        self.hparams = params

        # Data loading
        # =============================================================================
        # Map Object
        self.map_storage = MapStorage(self.hparams.data_param.map_root)

    def size(self):
        return self.shape

    def _move_to_model_device(self, x):
        """Return ``x`` on the device holding this model's parameters.

        Raises StopIteration if the model has no parameters.
        """
        return x.to(next(self.parameters()).device)

    def save_to_disk(self, model_path):
        """Pickle this model's class to ``<model_path>/model_class.obj`` unless that file exists.

        ``model_path`` may be a ``str`` or ``Path``; it is created (with parents) if missing.
        Always returns True.
        """
        model_path = Path(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        class_file = model_path / 'model_class.obj'
        if not class_file.exists():
            with class_file.open('wb') as f:
                torch.save(self.__class__, f)
        return True

    def _build_dataloader(self, dataset, shuffle):
        """DataLoader over ``dataset`` using the batch size / worker count from ``hparams.data_param``."""
        return DataLoader(dataset=dataset, shuffle=shuffle,
                          batch_size=self.hparams.data_param.batchsize,
                          num_workers=self.hparams.data_param.worker)

    @pl.data_loader
    def train_dataloader(self):
        return self._build_dataloader(self.dataset.train_dataset, shuffle=True)

    @pl.data_loader
    def test_dataloader(self):
        # Evaluation metrics do not depend on sample order; keep it deterministic.
        return self._build_dataloader(self.dataset.test_dataset, shuffle=False)

    @pl.data_loader
    def val_dataloader(self):
        # Evaluation metrics do not depend on sample order; keep it deterministic.
        return self._build_dataloader(self.dataset.val_dataset, shuffle=False)

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def validation_step(self, *args, **kwargs):
        raise NotImplementedError

    def validation_end(self, outputs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_end(self, outputs):
        """Compute, print and return the ROC-AUC over all test-step outputs.

        :param outputs: list of dicts, each holding tensors under ``'y_pred'`` (scores)
            and ``'y_true'`` (labels); they are concatenated along dim 0.
        :return: ``{'roc_auc_scores': <float>}``
        """
        from sklearn.metrics import roc_auc_score
        y_true = torch.cat([output['y_true'] for output in outputs], dim=0)
        # FIXME: What did this do; do I need it?
        # y_true = (y_true != V.HOMOTOPIC).long()
        y_scores = torch.cat([output['y_pred'] for output in outputs], dim=0)
        # detach(): .numpy() refuses tensors that still require grad.
        roc_auc_scores = roc_auc_score(y_true.detach().cpu().numpy(), y_scores.detach().cpu().numpy())
        print(f'AUC Score: {roc_auc_scores}')
        return {'roc_auc_scores': roc_auc_scores}

    def init_weights(self):
        """Xavier-uniform init for every weight with >= 2 dims; fill every bias with 0.01.

        1-D weights (e.g. normalisation-layer scales) are left untouched because
        Xavier initialisation is undefined for tensors with fewer than 2 dims.
        """
        def _weight_init(m):
            weight = getattr(m, 'weight', None)
            if isinstance(weight, torch.Tensor) and weight.dim() >= 2:
                torch.nn.init.xavier_uniform_(weight)
            bias = getattr(m, 'bias', None)
            if isinstance(bias, torch.Tensor):
                bias.data.fill_(0.01)
        self.apply(_weight_init)


class FilterLayer(nn.Module):
    """Select the last entry along dimension 1 (``x[:, -1]``), dropping that dimension."""

    def __init__(self):
        super(FilterLayer, self).__init__()

    def forward(self, x):
        tensor = x[:, -1]
        return tensor


class MergingLayer(nn.Module):
    """Placeholder: ``forward`` is not implemented yet and returns None."""

    def __init__(self):
        super(MergingLayer, self).__init__()

    def forward(self, x):
        # ToDo: Which ones to combine?
        return


class FlipTensor(nn.Module):
    """Reverse the order of elements along ``dim`` (default: second-to-last)."""

    def __init__(self, dim=-2):
        super(FlipTensor, self).__init__()
        self.dim = dim

    def forward(self, x):
        # torch.flip works on any device; the former hand-built index tensor lived on CPU only.
        return torch.flip(x, dims=(self.dim,))