Model Loading by string. Within Debugging
This commit is contained in:
parent
a4b6c698c3
commit
6bc9447ce1
@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
|
class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_shape, weight_matrix):
|
||||||
|
super(PreInitializedConvModule, self).__init__()
|
||||||
|
self.in_shape = in_shape
|
||||||
|
raise NotImplementedError
|
||||||
|
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
|
||||||
|
# override the weights then.
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SobelFilter(ShapeMixin, nn.Module):
|
class SobelFilter(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_shape):
|
def __init__(self, in_shape):
|
||||||
super(SobelFilter, self).__init__()
|
super(SobelFilter, self).__init__()
|
||||||
|
self.in_shape = in_shape
|
||||||
|
|
||||||
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
||||||
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
||||||
|
103
modules/util.py
103
modules/util.py
@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
|
|
||||||
# Dataset Loading
|
# Dataset Loading
|
||||||
################################
|
################################
|
||||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.shape
|
return self.shape
|
||||||
@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
|
|||||||
super(F_x, self).__init__()
|
super(F_x, self).__init__()
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
|
|
||||||
def forward(self, x):
|
@staticmethod
|
||||||
|
def forward(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -174,26 +175,22 @@ class WeightInit:
|
|||||||
m.bias.data.fill_(0.01)
|
m.bias.data.fill_(0.01)
|
||||||
|
|
||||||
|
|
||||||
class FilterLayer(nn.Module):
|
class Filter(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, in_shape, pos, dim=-1):
|
||||||
super(FilterLayer, self).__init__()
|
super(Filter, self).__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
self.in_shape = in_shape
|
||||||
|
self.pos = pos
|
||||||
|
self.dim = dim
|
||||||
|
raise SystemError('Do not use this Module - broken.')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(x):
|
||||||
tensor = x[:, -1]
|
tensor = x[:, -1]
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class MergingLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(MergingLayer, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# ToDo: Which ones to combine?
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class FlipTensor(nn.Module):
|
class FlipTensor(nn.Module):
|
||||||
def __init__(self, dim=-2):
|
def __init__(self, dim=-2):
|
||||||
super(FlipTensor, self).__init__()
|
super(FlipTensor, self).__init__()
|
||||||
@ -223,43 +220,53 @@ class AutoPadToShape(object):
|
|||||||
return f'AutoPadTransform({self.shape})'
|
return f'AutoPadTransform({self.shape})'
|
||||||
|
|
||||||
|
|
||||||
class HorizontalSplitter(nn.Module):
|
class Splitter(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_shape, n):
|
|
||||||
super(HorizontalSplitter, self).__init__()
|
|
||||||
assert len(in_shape) == 3
|
|
||||||
self.n = n
|
|
||||||
self.in_shape = in_shape
|
|
||||||
|
|
||||||
self.channel, self.height, self.width = self.in_shape
|
|
||||||
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
|
|
||||||
|
|
||||||
self.shape = (self.channel, self.new_height, self.width)
|
|
||||||
self.autopad = AutoPadToShape(self.shape)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
n_blocks = list()
|
|
||||||
for block_idx in range(self.n):
|
|
||||||
start = block_idx * self.new_height
|
|
||||||
end = (block_idx + 1) * self.new_height
|
|
||||||
block = self.autopad(x[:, :, start:end, :])
|
|
||||||
n_blocks.append(block)
|
|
||||||
|
|
||||||
return n_blocks
|
|
||||||
|
|
||||||
|
|
||||||
class HorizontalMerger(nn.Module):
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
|
return tuple([self._out_shape] * self.n)
|
||||||
return merged_shape
|
|
||||||
|
@property
|
||||||
|
def out_shape(self):
|
||||||
|
return self._out_shape
|
||||||
|
|
||||||
|
def __init__(self, in_shape, n, dim=-1):
|
||||||
|
super(Splitter, self).__init__()
|
||||||
|
|
||||||
def __init__(self, in_shape, n):
|
|
||||||
super(HorizontalMerger, self).__init__()
|
|
||||||
assert len(in_shape) == 3
|
|
||||||
self.n = n
|
self.n = n
|
||||||
|
self.dim = dim
|
||||||
|
self.in_shape = in_shape
|
||||||
|
|
||||||
|
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
|
||||||
|
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
|
||||||
|
|
||||||
|
self.autopad = AutoPadToShape(self._out_shape)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = x.transpose(0, self.dim)
|
||||||
|
n_blocks = list()
|
||||||
|
for block_idx in range(self.n):
|
||||||
|
start = block_idx * self.new_dim_size
|
||||||
|
end = (block_idx + 1) * self.new_dim_size
|
||||||
|
block = self.autopad(x[:, :, start:end, :])
|
||||||
|
|
||||||
|
n_blocks.append(block.transpose(0, self.dim))
|
||||||
|
return n_blocks
|
||||||
|
|
||||||
|
|
||||||
|
class Merger(nn.Module):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
y = self.forward([torch.randn(self.in_shape)])
|
||||||
|
return y.shape
|
||||||
|
|
||||||
|
def __init__(self, in_shape, n, dim=-1):
|
||||||
|
super(Merger, self).__init__()
|
||||||
|
|
||||||
|
self.n = n
|
||||||
|
self.dim = dim
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.cat(x, dim=-2)
|
return torch.cat(x, dim=self.dim)
|
||||||
|
@ -7,9 +7,11 @@ from abc import ABC
|
|||||||
|
|
||||||
from argparse import Namespace, ArgumentParser
|
from argparse import Namespace, ArgumentParser
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser, DuplicateSectionError
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from ml_lib.utils.tools import locate_and_import_class
|
||||||
|
|
||||||
|
|
||||||
def is_jsonable(x):
|
def is_jsonable(x):
|
||||||
import json
|
import json
|
||||||
@ -90,11 +92,13 @@ class Config(ConfigParser, ABC):
|
|||||||
@property
|
@property
|
||||||
def model_class(self):
|
def model_class(self):
|
||||||
try:
|
try:
|
||||||
return self._model_map[self.model.type]
|
return locate_and_import_class(self.model.type)
|
||||||
except KeyError:
|
except AttributeError as e:
|
||||||
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
|
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
|
||||||
'does not exist! Try one of these: {list(self._model_map.keys())}')
|
f'was not found!\n' +
|
||||||
|
f'{e}')
|
||||||
|
|
||||||
|
# --------------------------------------------------
|
||||||
# TODO: Do this programmatically; This did not work:
|
# TODO: Do this programmatically; This did not work:
|
||||||
# Initialize Default Sections as Property
|
# Initialize Default Sections as Property
|
||||||
# for section in self.default_sections:
|
# for section in self.default_sections:
|
||||||
@ -223,3 +227,16 @@ class Config(ConfigParser, ABC):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
|
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
|
||||||
|
|
||||||
|
def add_section(self, section: str) -> None:
|
||||||
|
try:
|
||||||
|
super(Config, self).add_section(section)
|
||||||
|
except DuplicateSectionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataClass(Namespace):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __dict__(self):
|
||||||
|
return [x for x in dir(self) if not x.startswith('_')]
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import argparse
|
|
||||||
from typing import Union, Dict, Optional, Any
|
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
import importlib
|
||||||
import pickle
|
import pickle
|
||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path, PurePath
|
||||||
|
from pydoc import safeimport
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -37,3 +40,16 @@ def load_from_shelve(file_path, key):
|
|||||||
def check_path(file_path):
|
def check_path(file_path):
|
||||||
assert isinstance(file_path, Path)
|
assert isinstance(file_path, Path)
|
||||||
assert str(file_path).endswith('.pik')
|
assert str(file_path).endswith('.pik')
|
||||||
|
|
||||||
|
|
||||||
|
def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False):
|
||||||
|
"""Locate an object by name or dotted path, importing as necessary."""
|
||||||
|
models_location = Path(models_location)
|
||||||
|
module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||||
|
for module_path in module_paths:
|
||||||
|
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
||||||
|
try:
|
||||||
|
model_class = mod.__getattribute__(class_name)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
return model_class
|
||||||
|
Loading…
x
Reference in New Issue
Block a user