from functools import reduce from matplotlib import pyplot as plt from abc import ABC from pathlib import Path import torch from operator import mul from pytorch_lightning.utilities import argparse_utils from torch import nn from torch.nn import functional as F, Unfold from sklearn.metrics import ConfusionMatrixDisplay # Utility - Modules ################### from ..metrics.binary_class_classifictaion import BinaryScores from ..metrics.multi_class_classification import MultiClassScores from ..utils.model_io import ModelParameters from ..utils.tools import add_argparse_args try: import pytorch_lightning as pl class PLMetrics(pl.metrics.Metric): def __init__(self, n_classes, tag=''): super(PLMetrics, self).__init__() self.n_classes = n_classes self.tag = tag self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,) self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False, is_multiclass=True) self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False, is_multiclass=True) self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False) # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False) # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True) # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False) if self.n_classes > 2: self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False) self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False) def __iter__(self): return iter(((name, metric) for name, metric in self._modules.items())) def update(self, preds, target) -> None: for _, metric in self: try: if self.n_classes <= 2: metric.update(preds, target) else: metric.update(preds, target) except ValueError: print(f'error was: {ValueError}') print(f'Metric is: {metric}') print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}') metric.update(preds.squeeze(), target) except AssertionError: print(f'error was: {AssertionError}') print(f'Metric is: {metric}') print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}') metric.update(preds, target.unsqueeze(-1)) def reset(self) -> None: for _, metric in self: metric.reset() def compute(self) -> dict: tag = f'{self.tag}_' if self.tag else '' return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self} def compute_and_prepare(self): pl_metrics = self.compute() images_from_metrics = dict() for metric_name in list(pl_metrics.keys()): if 'curve' in metric_name: continue roc_curve = pl_metrics.pop(metric_name) print('debug_point') elif 'matrix' in metric_name: matrix = pl_metrics.pop(metric_name) fig1, ax1 = plt.subplots(dpi=96) disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(), display_labels=[i for i in range(self.n_classes)] ) disp.plot(include_values=True, ax=ax1) images_from_metrics[metric_name] = fig1 elif 'ROC' in metric_name: continue roc = pl_metrics.pop(metric_name) print('debug_point') else: pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item() return pl_metrics, images_from_metrics class LightningBaseModule(pl.LightningModule, ABC): @classmethod def name(cls): return cls.__name__ @property def shape(self): try: x = torch.randn(self.in_shape).unsqueeze(0) output = self(x) return output.shape[1:] except Exception as e: print(e) return -1 @classmethod def from_argparse_args(cls, args, **kwargs): return argparse_utils.from_argparse_args(cls, args, **kwargs) @classmethod def add_argparse_args(cls, parent_parser): return add_argparse_args(cls, parent_parser) def __init__(self, model_parameters, weight_init='xavier_normal_'): super(LightningBaseModule, self).__init__() self._weight_init = weight_init self.params = ModelParameters(model_parameters) if hasattr(self.params, 'n_classes'): self.metrics = PLMetrics(self.params.n_classes, tag='PL') else: pass def size(self): return self.shape def save_to_disk(self, model_path): Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True) if not (model_path / 'model_class.obj').exists(): with (model_path / 'model_class.obj').open('wb') as f: torch.save(self.__class__, f) return True @property def data_len(self): return len(self.dataset.train_dataset) @property def n_train_batches(self): return len(self.train_dataloader()) def configure_optimizers(self): raise NotImplementedError def forward(self, *args, **kwargs): raise NotImplementedError def training_step(self, batch_xy, batch_nb, *args, **kwargs): raise NotImplementedError def test_step(self, *args, **kwargs): raise NotImplementedError def test_epoch_end(self, outputs): raise NotImplementedError def init_weights(self): if isinstance(self._weight_init, str): mod = __import__('torch.nn.init', fromlist=[self._weight_init]) self._weight_init = getattr(mod, self._weight_init) assert callable(self._weight_init) weight_initializer = WeightInit(in_place_init_function=self._weight_init) self.apply(weight_initializer) def additional_scores(self, outputs): if self.params.n_classes > 2: return MultiClassScores(self)(outputs) else: return BinaryScores(self)(outputs) module_types = (LightningBaseModule, nn.Module,) except ImportError: module_types = (nn.Module,) pl = None pass # Maybe post a hint to install pytorch-lightning. class ShapeMixin: @property def shape(self): assert isinstance(self, module_types) def get_out_shape(output): return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1] in_shape = self.in_shape if hasattr(self, 'in_shape') else None if in_shape is not None: try: device = self.device except AttributeError: try: device = next(self.parameters()).device except StopIteration: device = 'cuda' if torch.cuda.is_available() else 'cpu' x = torch.randn(in_shape, device=device) # This is needed for BatchNorm shape checking x = torch.stack((x, x)) # noinspection PyCallingNonCallable y = self(x) if isinstance(y, tuple): shape = tuple([get_out_shape(y[i]) for i in range(len(y))]) else: shape = get_out_shape(y) return shape else: return -1 @property def flat_shape(self): shape = self.shape try: return reduce(mul, shape) except TypeError: return shape class F_x(ShapeMixin, nn.Identity): def __init__(self, in_shape): super(F_x, self).__init__() self.in_shape = in_shape class SlidingWindow(ShapeMixin, nn.Module): def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False): super(SlidingWindow, self).__init__() self.in_shape = in_shape self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel) self.padding = padding self.stride = stride self.keepdim = keepdim self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride) def forward(self, x): tensor = self._unfolder(x) tensor = tensor.transpose(-1, -2) if self.keepdim: shape = *x.shape[:2], -1, *self.kernel tensor = tensor.reshape(shape) return tensor # Utility - Modules ################### class Flatten(ShapeMixin, nn.Module): def __init__(self, in_shape, to=-1): assert isinstance(to, int) or isinstance(to, tuple) super(Flatten, self).__init__() self.in_shape = in_shape self.to = (to,) if isinstance(to, int) else to def forward(self, x): return x.view(x.size(0), *self.to) class Interpolate(nn.Module): def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): super(Interpolate, self).__init__() self.interp = nn.functional.interpolate self.size = size self.scale_factor = scale_factor self.align_corners = align_corners self.mode = mode def forward(self, x): x = self.interp(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x class AutoPad(nn.Module): def __init__(self, interpolations=3, base=2): super(AutoPad, self).__init__() self.fct = base ** interpolations def forward(self, x): # noinspection PyUnresolvedReferences x = F.pad(x, [0, (x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0, (x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0, 0]) return x class WeightInit: def __init__(self, in_place_init_function): self.in_place_init_function = in_place_init_function def __call__(self, m): if hasattr(m, 'weight'): if isinstance(m.weight, torch.Tensor): if m.weight.ndim < 2: m.weight.data.fill_(0.01) else: self.in_place_init_function(m.weight) if hasattr(m, 'bias'): if isinstance(m.bias, torch.Tensor): m.bias.data.fill_(0.01) class Filter(nn.Module, ShapeMixin): def __init__(self, in_shape, pos, dim=-1): super(Filter, self).__init__() self.in_shape = in_shape self.pos = pos self.dim = dim raise SystemError('Do not use this Module - broken.') @staticmethod def forward(x): tensor = x[:, -1] return tensor class FlipTensor(nn.Module): def __init__(self, dim=-2): super(FlipTensor, self).__init__() self.dim = dim def forward(self, x): idx = [i for i in range(x.size(self.dim) - 1, -1, -1)] idx = torch.as_tensor(idx).long() inverted_tensor = x.index_select(self.dim, idx) return inverted_tensor class AutoPadToShape(nn.Module): def __init__(self, target_shape): super(AutoPadToShape, self).__init__() self.target_shape = target_shape def forward(self, x): if not torch.is_tensor(x): x = torch.as_tensor(x) if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape: return x idx = [0] * (len(self.target_shape) * 2) for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)): idx[j] = self.target_shape[i] - x.shape[i] x = torch.nn.functional.pad(x, idx) return x def __repr__(self): return f'AutoPadTransform({self.target_shape})' class Splitter(nn.Module): @property def shape(self): return tuple([self._out_shape] * self.n) @property def out_shape(self): return self._out_shape def __init__(self, in_shape, n, dim=-1): super(Splitter, self).__init__() self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape self.n = n self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim) self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0) self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)]) self.autopad = AutoPadToShape(self._out_shape) def forward(self, x: torch.Tensor): dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim x = x.transpose(0, dim) n_blocks = list() for block_idx in range(self.n): start = block_idx * self.new_dim_size end = (block_idx + 1) * self.new_dim_size block = x[start:end].transpose(0, dim) block = self.autopad(block) n_blocks.append(block) return n_blocks class Merger(nn.Module, ShapeMixin): @property def shape(self): y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)]) return y.shape def __init__(self, in_shape, n, dim=-1): super(Merger, self).__init__() self.n = n self.dim = dim self.in_shape = in_shape def forward(self, x): return torch.cat(x, dim=self.dim)