Model Loading by string. Within Debugging

This commit is contained in:
Si11ium 2020-08-15 12:42:57 +02:00
parent a4b6c698c3
commit 6bc9447ce1
5 changed files with 108 additions and 58 deletions

View File

@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
return tensor
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
class PreInitializedConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, weight_matrix):
super(PreInitializedConvModule, self).__init__()
self.in_shape = in_shape
raise NotImplementedError
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
# override the weights then.
def forward(self, x):
return x
class SobelFilter(ShapeMixin, nn.Module):
def __init__(self, in_shape):
super(SobelFilter, self).__init__()
self.in_shape = in_shape
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)

View File

@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
def size(self):
return self.shape
@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
super(F_x, self).__init__()
self.in_shape = in_shape
def forward(self, x):
@staticmethod
def forward(x):
return x
@ -174,26 +175,22 @@ class WeightInit:
m.bias.data.fill_(0.01)
class FilterLayer(nn.Module):
class Filter(nn.Module):
def __init__(self):
super(FilterLayer, self).__init__()
def __init__(self, in_shape, pos, dim=-1):
super(Filter, self).__init__()
def forward(self, x):
self.in_shape = in_shape
self.pos = pos
self.dim = dim
raise SystemError('Do not use this Module - broken.')
@staticmethod
def forward(x):
tensor = x[:, -1]
return tensor
class MergingLayer(nn.Module):
def __init__(self):
super(MergingLayer, self).__init__()
def forward(self, x):
# ToDo: Which ones to combine?
return
class FlipTensor(nn.Module):
def __init__(self, dim=-2):
super(FlipTensor, self).__init__()
@ -223,43 +220,53 @@ class AutoPadToShape(object):
return f'AutoPadTransform({self.shape})'
class HorizontalSplitter(nn.Module):
def __init__(self, in_shape, n):
super(HorizontalSplitter, self).__init__()
assert len(in_shape) == 3
self.n = n
self.in_shape = in_shape
self.channel, self.height, self.width = self.in_shape
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
self.shape = (self.channel, self.new_height, self.width)
self.autopad = AutoPadToShape(self.shape)
def forward(self, x):
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_height
end = (block_idx + 1) * self.new_height
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block)
return n_blocks
class HorizontalMerger(nn.Module):
class Splitter(nn.Module):
@property
def shape(self):
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
return merged_shape
return tuple([self._out_shape] * self.n)
@property
def out_shape(self):
return self._out_shape
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
def __init__(self, in_shape, n):
super(HorizontalMerger, self).__init__()
assert len(in_shape) == 3
self.n = n
self.dim = dim
self.in_shape = in_shape
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
x = x.transpose(0, self.dim)
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_dim_size
end = (block_idx + 1) * self.new_dim_size
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block.transpose(0, self.dim))
return n_blocks
class Merger(nn.Module):
@property
def shape(self):
y = self.forward([torch.randn(self.in_shape)])
return y.shape
def __init__(self, in_shape, n, dim=-1):
super(Merger, self).__init__()
self.n = n
self.dim = dim
self.in_shape = in_shape
def forward(self, x):
return torch.cat(x, dim=-2)
return torch.cat(x, dim=self.dim)

View File

@ -7,9 +7,11 @@ from abc import ABC
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from configparser import ConfigParser, DuplicateSectionError
import hashlib
from ml_lib.utils.tools import locate_and_import_class
def is_jsonable(x):
import json
@ -90,11 +92,13 @@ class Config(ConfigParser, ABC):
@property
def model_class(self):
try:
return self._model_map[self.model.type]
except KeyError:
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
'does not exist! Try one of these: {list(self._model_map.keys())}')
return locate_and_import_class(self.model.type)
except AttributeError as e:
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
f'was not found!\n' +
f'{e}')
# --------------------------------------------------
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
@ -223,3 +227,16 @@ class Config(ConfigParser, ABC):
return
else:
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
def add_section(self, section: str) -> None:
try:
super(Config, self).add_section(section)
except DuplicateSectionError:
pass
class DataClass(Namespace):
@property
def __dict__(self):
return [x for x in dir(self) if not x.startswith('_')]

View File

@ -1,6 +1,3 @@
import argparse
from typing import Union, Dict, Optional, Any
from abc import ABC
from pathlib import Path

View File

@ -1,6 +1,9 @@
import importlib
import pickle
import shelve
from pathlib import Path
from pathlib import Path, PurePath
from pydoc import safeimport
from typing import Union
import numpy as np
import torch
@ -37,3 +40,16 @@ def load_from_shelve(file_path, key):
def check_path(file_path):
assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik')
def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False):
"""Locate an object by name or dotted path, importing as necessary."""
models_location = Path(models_location)
module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name]
for module_path in module_paths:
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
try:
model_class = mod.__getattribute__(class_name)
except AttributeError:
continue
return model_class