from abc import ABC
from pathlib import Path

import torch
from torch import nn
# NOTE(review): this binds ``torch.functional`` (not ``torch.nn.functional``), which does
# not export ``pad``; padding in this module therefore goes through ``nn.functional``.
from torch import functional as F
import pytorch_lightning as pl

# Utility - Modules
###################
from ..utils.model_io import ModelParameters


def _forward_without_side_effects(module, x):
    """Run ``module(x)`` for shape probing without mutating the module.

    The module is switched to eval mode (so layers such as BatchNorm do not update
    their running statistics from the random probe) and autograd is disabled. The
    original training flag is restored afterwards, even if the forward pass raises.
    """
    was_training = module.training
    module.eval()
    try:
        with torch.no_grad():
            return module(x)
    finally:
        module.train(was_training)


class ShapeMixin:
    """Mixin exposing the per-sample output shape of a module with an ``in_shape`` attribute."""

    @property
    def shape(self):
        """Per-sample output shape, computed by a probe forward pass.

        Returns ``output.shape[1:]`` when it has more than one dimension, otherwise the
        size of the last output dimension as an int. Returns ``-1`` when ``in_shape`` is
        ``None``.
        """
        assert isinstance(self, (LightningBaseModule, nn.Module))
        if self.in_shape is None:
            return -1
        sample = torch.randn(self.in_shape)
        # A batch of two keeps BatchNorm-style layers happy with the probe batch.
        batch = torch.stack((sample, sample))
        output = _forward_without_side_effects(self, batch)
        per_sample_shape = output.shape[1:]
        return per_sample_shape if len(per_sample_shape) > 1 else output.shape[-1]


class F_x(ShapeMixin, nn.Module):
    """Identity module that also records ``in_shape`` for ``ShapeMixin``."""

    def __init__(self, in_shape):
        super(F_x, self).__init__()
        self.in_shape = in_shape

    def forward(self, x):
        return x


# Utility - Modules
###################
class Flatten(ShapeMixin, nn.Module):
    """Reshape every sample to ``to`` while keeping the batch dimension.

    :param in_shape: per-sample input shape, used by ``ShapeMixin.shape``.
    :param to: target per-sample shape as an int or tuple (``-1`` flattens).
    """

    def __init__(self, in_shape, to=-1):
        assert isinstance(to, int) or isinstance(to, tuple)
        super(Flatten, self).__init__()
        self.in_shape = in_shape
        self.to = (to,) if isinstance(to, int) else to

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after a transpose.
        return x.reshape(x.size(0), *self.to)


class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed arguments."""

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
                        mode=self.mode, align_corners=self.align_corners)
        return x


class AutoPad(nn.Module):
    """Pad the last two dims of ``x`` up to the next multiple of ``base ** interpolations``."""

    def __init__(self, interpolations=3, base=2):
        super(AutoPad, self).__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        # (-size) % fct is the distance to the next multiple of fct (0 if already aligned).
        pad_width = (-x.shape[-1]) % self.fct
        pad_height = (-x.shape[-2]) % self.fct
        # Pad list order: (last dim before, last dim after, 2nd-last before, 2nd-last after),
        # so width is padded after the data and height before it.
        # NOTE(review): padding the height before (top) rather than after looks unusual -
        # confirm it is intended.
        return nn.functional.pad(x, [0, pad_width, pad_height, 0])


class WeightInit:
    """Callable for ``nn.Module.apply`` that initialises weights and biases in place.

    Weights with fewer than two dims and all bias tensors are filled with ``0.01``;
    other weight tensors are passed to ``in_place_init_function``.
    """

    def __init__(self, in_place_init_function):
        self.in_place_init_function = in_place_init_function

    def __call__(self, m):
        if hasattr(m, 'weight'):
            if isinstance(m.weight, torch.Tensor):
                if m.weight.ndim < 2:
                    m.weight.data.fill_(0.01)
                else:
                    self.in_place_init_function(m.weight)
        if hasattr(m, 'bias'):
            if isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)


class LightningBaseModule(pl.LightningModule, ABC):
    """Common base for the project's Lightning models; subclasses implement the hooks."""

    @classmethod
    def name(cls):
        return cls.__name__

    @property
    def shape(self):
        """Per-sample output shape from a single-sample probe, or ``-1`` on failure."""
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            output = _forward_without_side_effects(self, x)
            return output.shape[1:]
        except Exception as e:
            # Best-effort: report the problem and fall back to -1.
            print(e)
            return -1

    def __init__(self, hparams):
        super(LightningBaseModule, self).__init__()

        # Set Parameters
        ################################
        self.hparams = hparams
        self.params = ModelParameters(hparams)

        # Dataset Loading
        ################################
        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here

    def size(self):
        return self.shape

    def save_to_disk(self, model_path):
        """Pickle this model's class to ``<model_path>/model_class.obj`` unless it exists.

        :param model_path: target directory (``str`` or ``Path``); created if missing.
        :returns: always ``True``.
        """
        model_path = Path(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        class_file = model_path / 'model_class.obj'
        if not class_file.exists():
            with class_file.open('wb') as f:
                torch.save(self.__class__, f)
        return True

    @property
    def data_len(self):
        # assumes a subclass sets ``self.dataset`` with a ``train_dataset`` attribute - TODO confirm
        return len(self.dataset.train_dataset)

    @property
    def n_train_batches(self):
        return len(self.train_dataloader())

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_epoch_end(self, outputs):
        raise NotImplementedError

    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
        """Apply ``WeightInit(in_place_init_func_)`` to every submodule."""
        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
        self.apply(weight_initializer)


class FilterLayer(nn.Module):
    """Select the last entry along dim 1 (``x[:, -1]``)."""

    def __init__(self):
        super(FilterLayer, self).__init__()

    def forward(self, x):
        tensor = x[:, -1]
        return tensor


class MergingLayer(nn.Module):
    """Placeholder merging layer; ``forward`` is not implemented yet and returns ``None``."""

    def __init__(self):
        super(MergingLayer, self).__init__()

    def forward(self, x):
        # ToDo: Which ones to combine?
        return


class FlipTensor(nn.Module):
    """Reverse ``x`` along ``dim`` (default ``-2``)."""

    def __init__(self, dim=-2):
        super(FlipTensor, self).__init__()
        self.dim = dim

    def forward(self, x):
        # torch.flip works on any device; a hand-built CPU index broke CUDA inputs.
        return torch.flip(x, dims=(self.dim,))


class AutoPadToShape(object):
    """Zero-pad each sample of a batch (at the end of every dim) to ``shape``.

    Inputs that are not tensors are converted with ``torch.as_tensor``. Inputs whose
    per-sample shape already equals ``shape`` are returned unchanged.
    """

    def __init__(self, shape):
        self.shape = shape

    def __call__(self, x):
        if not torch.is_tensor(x):
            x = torch.as_tensor(x)
        if x.shape[1:] == self.shape:
            return x
        # Allocate on the input's device so GPU batches can be padded.
        embedding = torch.zeros((x.shape[0], *self.shape), device=x.device)
        # Copy x into the leading corner of every per-sample dim (any rank).
        region = (slice(None),) + tuple(slice(0, size) for size in x.shape[1:])
        embedding[region] = x
        return embedding

    def __repr__(self):
        return f'AutoPadTransform({self.shape})'


class HorizontalSplitter(nn.Module):
    """Split ``(batch, C, H, W)`` input into ``n`` equal-height horizontal strips.

    Each strip has height ``ceil(H / n)``; shorter trailing strips are zero-padded to it.
    """

    def __init__(self, in_shape, n):
        super(HorizontalSplitter, self).__init__()
        assert len(in_shape) == 3
        self.n = n
        self.in_shape = in_shape

        self.channel, self.height, self.width = self.in_shape
        self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)

        self.shape = (self.channel, self.new_height, self.width)
        self.autopad = AutoPadToShape(self.shape)

    def forward(self, x):
        n_blocks = list()
        for block_idx in range(self.n):
            start = block_idx * self.new_height
            end = (block_idx + 1) * self.new_height
            block = self.autopad(x[:, :, start:end, :])
            n_blocks.append(block)

        return n_blocks


class HorizontalMerger(nn.Module):
    """Concatenate ``n`` strips back together along the height dim (``dim=-2``)."""

    @property
    def shape(self):
        merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
        return merged_shape

    def __init__(self, in_shape, n):
        super(HorizontalMerger, self).__init__()
        assert len(in_shape) == 3
        self.n = n
        self.in_shape = in_shape

    def forward(self, x):
        return torch.cat(x, dim=-2)