Model Loading by string. Within Debugging

This commit is contained in:
Si11ium 2020-08-15 12:42:57 +02:00
parent a4b6c698c3
commit 6bc9447ce1
5 changed files with 108 additions and 58 deletions

View File

@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
return tensor return tensor
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module): class PreInitializedConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, weight_matrix):
super(PreInitializedConvModule, self).__init__()
self.in_shape = in_shape
raise NotImplementedError
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
# override the weights then.
def forward(self, x):
return x
class SobelFilter(ShapeMixin, nn.Module): class SobelFilter(ShapeMixin, nn.Module):
def __init__(self, in_shape): def __init__(self, in_shape):
super(SobelFilter, self).__init__() super(SobelFilter, self).__init__()
self.in_shape = in_shape
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3) self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3) self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)

View File

@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Dataset Loading # Dataset Loading
################################ ################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
def size(self): def size(self):
return self.shape return self.shape
@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
super(F_x, self).__init__() super(F_x, self).__init__()
self.in_shape = in_shape self.in_shape = in_shape
def forward(self, x): @staticmethod
def forward(x):
return x return x
@ -174,26 +175,22 @@ class WeightInit:
m.bias.data.fill_(0.01) m.bias.data.fill_(0.01)
class FilterLayer(nn.Module): class Filter(nn.Module):
def __init__(self): def __init__(self, in_shape, pos, dim=-1):
super(FilterLayer, self).__init__() super(Filter, self).__init__()
def forward(self, x): self.in_shape = in_shape
self.pos = pos
self.dim = dim
raise SystemError('Do not use this Module - broken.')
@staticmethod
def forward(x):
tensor = x[:, -1] tensor = x[:, -1]
return tensor return tensor
class MergingLayer(nn.Module):
def __init__(self):
super(MergingLayer, self).__init__()
def forward(self, x):
# ToDo: Which ones to combine?
return
class FlipTensor(nn.Module): class FlipTensor(nn.Module):
def __init__(self, dim=-2): def __init__(self, dim=-2):
super(FlipTensor, self).__init__() super(FlipTensor, self).__init__()
@ -223,43 +220,53 @@ class AutoPadToShape(object):
return f'AutoPadTransform({self.shape})' return f'AutoPadTransform({self.shape})'
class HorizontalSplitter(nn.Module): class Splitter(nn.Module):
def __init__(self, in_shape, n):
super(HorizontalSplitter, self).__init__()
assert len(in_shape) == 3
self.n = n
self.in_shape = in_shape
self.channel, self.height, self.width = self.in_shape
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
self.shape = (self.channel, self.new_height, self.width)
self.autopad = AutoPadToShape(self.shape)
def forward(self, x):
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_height
end = (block_idx + 1) * self.new_height
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block)
return n_blocks
class HorizontalMerger(nn.Module):
@property @property
def shape(self): def shape(self):
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2] return tuple([self._out_shape] * self.n)
return merged_shape
@property
def out_shape(self):
return self._out_shape
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
def __init__(self, in_shape, n):
super(HorizontalMerger, self).__init__()
assert len(in_shape) == 3
self.n = n self.n = n
self.dim = dim
self.in_shape = in_shape
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
x = x.transpose(0, self.dim)
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_dim_size
end = (block_idx + 1) * self.new_dim_size
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block.transpose(0, self.dim))
return n_blocks
class Merger(nn.Module):
@property
def shape(self):
y = self.forward([torch.randn(self.in_shape)])
return y.shape
def __init__(self, in_shape, n, dim=-1):
super(Merger, self).__init__()
self.n = n
self.dim = dim
self.in_shape = in_shape self.in_shape = in_shape
def forward(self, x): def forward(self, x):
return torch.cat(x, dim=-2) return torch.cat(x, dim=self.dim)

View File

@ -7,9 +7,11 @@ from abc import ABC
from argparse import Namespace, ArgumentParser from argparse import Namespace, ArgumentParser
from collections import defaultdict from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser, DuplicateSectionError
import hashlib import hashlib
from ml_lib.utils.tools import locate_and_import_class
def is_jsonable(x): def is_jsonable(x):
import json import json
@ -90,11 +92,13 @@ class Config(ConfigParser, ABC):
@property @property
def model_class(self): def model_class(self):
try: try:
return self._model_map[self.model.type] return locate_and_import_class(self.model.type)
except KeyError: except AttributeError as e:
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' + raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
'does not exist! Try one of these: {list(self._model_map.keys())}') f'was not found!\n' +
f'{e}')
# --------------------------------------------------
# TODO: Do this programmatically; This did not work: # TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property # Initialize Default Sections as Property
# for section in self.default_sections: # for section in self.default_sections:
@ -223,3 +227,16 @@ class Config(ConfigParser, ABC):
return return
else: else:
super(Config, self)._write_section(fp, section_name, section_items, delimiter) super(Config, self)._write_section(fp, section_name, section_items, delimiter)
def add_section(self, section: str) -> None:
try:
super(Config, self).add_section(section)
except DuplicateSectionError:
pass
class DataClass(Namespace):
@property
def __dict__(self):
return [x for x in dir(self) if not x.startswith('_')]

View File

@ -1,6 +1,3 @@
import argparse
from typing import Union, Dict, Optional, Any
from abc import ABC from abc import ABC
from pathlib import Path from pathlib import Path

View File

@ -1,6 +1,9 @@
import importlib
import pickle import pickle
import shelve import shelve
from pathlib import Path from pathlib import Path, PurePath
from pydoc import safeimport
from typing import Union
import numpy as np import numpy as np
import torch import torch
@ -37,3 +40,16 @@ def load_from_shelve(file_path, key):
def check_path(file_path): def check_path(file_path):
assert isinstance(file_path, Path) assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik') assert str(file_path).endswith('.pik')
def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False):
"""Locate an object by name or dotted path, importing as necessary."""
models_location = Path(models_location)
module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name]
for module_path in module_paths:
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
try:
model_class = mod.__getattribute__(class_name)
except AttributeError:
continue
return model_class