From 6bc9447ce1df4f41d90615d1975ea5e03ab08c6f Mon Sep 17 00:00:00 2001 From: Si11ium Date: Sat, 15 Aug 2020 12:42:57 +0200 Subject: [PATCH] Model Loading by string. Within Debugging --- modules/blocks.py | 15 ++++++- modules/util.py | 103 +++++++++++++++++++++++++--------------------- utils/config.py | 27 +++++++++--- utils/logging.py | 3 -- utils/tools.py | 18 +++++++- 5 files changed, 108 insertions(+), 58 deletions(-) diff --git a/modules/blocks.py b/modules/blocks.py index 791da14..fe19b8b 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module): return tensor -# TODO class PreInitializedConvModule(ShapeMixin, nn.Module): +class PreInitializedConvModule(ShapeMixin, nn.Module): + + def __init__(self, in_shape, weight_matrix): + super(PreInitializedConvModule, self).__init__() + self.in_shape = in_shape + raise NotImplementedError + # ToDo Get the weight_matrix shape and init a conv_module of similar size, + # override the weights then. + + def forward(self, x): + + return x + class SobelFilter(ShapeMixin, nn.Module): def __init__(self, in_shape): super(SobelFilter, self).__init__() + self.in_shape = in_shape self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3) self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3) diff --git a/modules/util.py b/modules/util.py index 1329ab2..1b29853 100644 --- a/modules/util.py +++ b/modules/util.py @@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC): # Dataset Loading ################################ - # TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here + # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here def size(self): return self.shape @@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module): super(F_x, self).__init__() self.in_shape = in_shape - def forward(self, x): + @staticmethod + def forward(x): return x @@ -174,26 +175,22 @@ class WeightInit: m.bias.data.fill_(0.01) -class FilterLayer(nn.Module): +class Filter(nn.Module): - def __init__(self): - super(FilterLayer, self).__init__() + def __init__(self, in_shape, pos, dim=-1): + super(Filter, self).__init__() - def forward(self, x): + self.in_shape = in_shape + self.pos = pos + self.dim = dim + raise SystemError('Do not use this Module - broken.') + + @staticmethod + def forward(x): tensor = x[:, -1] return tensor -class MergingLayer(nn.Module): - - def __init__(self): - super(MergingLayer, self).__init__() - - def forward(self, x): - # ToDo: Which ones to combine? - return - - class FlipTensor(nn.Module): def __init__(self, dim=-2): super(FlipTensor, self).__init__() @@ -223,43 +220,53 @@ class AutoPadToShape(object): return f'AutoPadTransform({self.shape})' -class HorizontalSplitter(nn.Module): - - def __init__(self, in_shape, n): - super(HorizontalSplitter, self).__init__() - assert len(in_shape) == 3 - self.n = n - self.in_shape = in_shape - - self.channel, self.height, self.width = self.in_shape - self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0) - - self.shape = (self.channel, self.new_height, self.width) - self.autopad = AutoPadToShape(self.shape) - - def forward(self, x): - n_blocks = list() - for block_idx in range(self.n): - start = block_idx * self.new_height - end = (block_idx + 1) * self.new_height - block = self.autopad(x[:, :, start:end, :]) - n_blocks.append(block) - - return n_blocks - - -class HorizontalMerger(nn.Module): +class Splitter(nn.Module): @property def shape(self): - merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2] - return merged_shape + return tuple([self._out_shape] * self.n) + + @property + def out_shape(self): + return self._out_shape + + def __init__(self, in_shape, n, dim=-1): + super(Splitter, self).__init__() - def __init__(self, in_shape, n): - super(HorizontalMerger, self).__init__() - assert len(in_shape) == 3 self.n = n + self.dim = dim + self.in_shape = in_shape + + self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0) + self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)]) + + self.autopad = AutoPadToShape(self._out_shape) + + def forward(self, x: torch.Tensor): + x = x.transpose(0, self.dim) + n_blocks = list() + for block_idx in range(self.n): + start = block_idx * self.new_dim_size + end = (block_idx + 1) * self.new_dim_size + block = self.autopad(x[:, :, start:end, :]) + + n_blocks.append(block.transpose(0, self.dim)) + return n_blocks + + +class Merger(nn.Module): + + @property + def shape(self): + y = self.forward([torch.randn(self.in_shape)]) + return y.shape + + def __init__(self, in_shape, n, dim=-1): + super(Merger, self).__init__() + + self.n = n + self.dim = dim self.in_shape = in_shape def forward(self, x): - return torch.cat(x, dim=-2) + return torch.cat(x, dim=self.dim) diff --git a/utils/config.py b/utils/config.py index b7040f9..b8bb63c 100644 --- a/utils/config.py +++ b/utils/config.py @@ -7,9 +7,11 @@ from abc import ABC from argparse import Namespace, ArgumentParser from collections import defaultdict -from configparser import ConfigParser +from configparser import ConfigParser, DuplicateSectionError import hashlib +from ml_lib.utils.tools import locate_and_import_class + def is_jsonable(x): import json @@ -90,11 +92,13 @@ class Config(ConfigParser, ABC): @property def model_class(self): try: - return self._model_map[self.model.type] - except KeyError: - raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' + - 'does not exist! Try one of these: {list(self._model_map.keys())}') + return locate_and_import_class(self.model.type) + except AttributeError as e: + raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' + + f'was not found!\n' + + f'{e}') + # -------------------------------------------------- # TODO: Do this programmatically; This did not work: # Initialize Default Sections as Property # for section in self.default_sections: @@ -223,3 +227,16 @@ class Config(ConfigParser, ABC): return else: super(Config, self)._write_section(fp, section_name, section_items, delimiter) + + def add_section(self, section: str) -> None: + try: + super(Config, self).add_section(section) + except DuplicateSectionError: + pass + + +class DataClass(Namespace): + + @property + def __dict__(self): + return [x for x in dir(self) if not x.startswith('_')] diff --git a/utils/logging.py b/utils/logging.py index 876fd4a..3c27474 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -1,6 +1,3 @@ -import argparse -from typing import Union, Dict, Optional, Any - from abc import ABC from pathlib import Path diff --git a/utils/tools.py b/utils/tools.py index 6e291f8..6d2395f 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -1,6 +1,9 @@ +import importlib import pickle import shelve -from pathlib import Path +from pathlib import Path, PurePath +from pydoc import safeimport +from typing import Union import numpy as np import torch @@ -37,3 +40,16 @@ def load_from_shelve(file_path, key): def check_path(file_path): assert isinstance(file_path, Path) assert str(file_path).endswith('.pik') + + +def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False): + """Locate an object by name or dotted path, importing as necessary.""" + models_location = Path(models_location) + module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name] + for module_path in module_paths: + mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts])) + try: + model_class = mod.__getattribute__(class_name) + except AttributeError: + continue + return model_class