Model Loading by string. Within Debugging
This commit is contained in:
parent
a4b6c698c3
commit
6bc9447ce1
@ -91,12 +91,25 @@ class ConvModule(ShapeMixin, nn.Module):
|
||||
return tensor
|
||||
|
||||
|
||||
# TODO class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||
class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, weight_matrix):
|
||||
super(PreInitializedConvModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
raise NotImplementedError
|
||||
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
|
||||
# override the weights then.
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SobelFilter(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape):
|
||||
super(SobelFilter, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
||||
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
||||
|
103
modules/util.py
103
modules/util.py
@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
|
||||
super(F_x, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
def forward(self, x):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return x
|
||||
|
||||
|
||||
@ -174,26 +175,22 @@ class WeightInit:
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
class Filter(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(FilterLayer, self).__init__()
|
||||
def __init__(self, in_shape, pos, dim=-1):
|
||||
super(Filter, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
self.in_shape = in_shape
|
||||
self.pos = pos
|
||||
self.dim = dim
|
||||
raise SystemError('Do not use this Module - broken.')
|
||||
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
tensor = x[:, -1]
|
||||
return tensor
|
||||
|
||||
|
||||
class MergingLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MergingLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# ToDo: Which ones to combine?
|
||||
return
|
||||
|
||||
|
||||
class FlipTensor(nn.Module):
|
||||
def __init__(self, dim=-2):
|
||||
super(FlipTensor, self).__init__()
|
||||
@ -223,43 +220,53 @@ class AutoPadToShape(object):
|
||||
return f'AutoPadTransform({self.shape})'
|
||||
|
||||
|
||||
class HorizontalSplitter(nn.Module):
|
||||
|
||||
def __init__(self, in_shape, n):
|
||||
super(HorizontalSplitter, self).__init__()
|
||||
assert len(in_shape) == 3
|
||||
self.n = n
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.channel, self.height, self.width = self.in_shape
|
||||
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
|
||||
|
||||
self.shape = (self.channel, self.new_height, self.width)
|
||||
self.autopad = AutoPadToShape(self.shape)
|
||||
|
||||
def forward(self, x):
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_height
|
||||
end = (block_idx + 1) * self.new_height
|
||||
block = self.autopad(x[:, :, start:end, :])
|
||||
n_blocks.append(block)
|
||||
|
||||
return n_blocks
|
||||
|
||||
|
||||
class HorizontalMerger(nn.Module):
|
||||
class Splitter(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
|
||||
return merged_shape
|
||||
return tuple([self._out_shape] * self.n)
|
||||
|
||||
@property
|
||||
def out_shape(self):
|
||||
return self._out_shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Splitter, self).__init__()
|
||||
|
||||
def __init__(self, in_shape, n):
|
||||
super(HorizontalMerger, self).__init__()
|
||||
assert len(in_shape) == 3
|
||||
self.n = n
|
||||
self.dim = dim
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
|
||||
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
|
||||
|
||||
self.autopad = AutoPadToShape(self._out_shape)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x.transpose(0, self.dim)
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_dim_size
|
||||
end = (block_idx + 1) * self.new_dim_size
|
||||
block = self.autopad(x[:, :, start:end, :])
|
||||
|
||||
n_blocks.append(block.transpose(0, self.dim))
|
||||
return n_blocks
|
||||
|
||||
|
||||
class Merger(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
y = self.forward([torch.randn(self.in_shape)])
|
||||
return y.shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Merger, self).__init__()
|
||||
|
||||
self.n = n
|
||||
self.dim = dim
|
||||
self.in_shape = in_shape
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(x, dim=-2)
|
||||
return torch.cat(x, dim=self.dim)
|
||||
|
@ -7,9 +7,11 @@ from abc import ABC
|
||||
|
||||
from argparse import Namespace, ArgumentParser
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from configparser import ConfigParser, DuplicateSectionError
|
||||
import hashlib
|
||||
|
||||
from ml_lib.utils.tools import locate_and_import_class
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
import json
|
||||
@ -90,11 +92,13 @@ class Config(ConfigParser, ABC):
|
||||
@property
|
||||
def model_class(self):
|
||||
try:
|
||||
return self._model_map[self.model.type]
|
||||
except KeyError:
|
||||
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
|
||||
'does not exist! Try one of these: {list(self._model_map.keys())}')
|
||||
return locate_and_import_class(self.model.type)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
|
||||
f'was not found!\n' +
|
||||
f'{e}')
|
||||
|
||||
# --------------------------------------------------
|
||||
# TODO: Do this programmatically; This did not work:
|
||||
# Initialize Default Sections as Property
|
||||
# for section in self.default_sections:
|
||||
@ -223,3 +227,16 @@ class Config(ConfigParser, ABC):
|
||||
return
|
||||
else:
|
||||
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
|
||||
|
||||
def add_section(self, section: str) -> None:
|
||||
try:
|
||||
super(Config, self).add_section(section)
|
||||
except DuplicateSectionError:
|
||||
pass
|
||||
|
||||
|
||||
class DataClass(Namespace):
|
||||
|
||||
@property
|
||||
def __dict__(self):
|
||||
return [x for x in dir(self) if not x.startswith('_')]
|
||||
|
@ -1,6 +1,3 @@
|
||||
import argparse
|
||||
from typing import Union, Dict, Optional, Any
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
import importlib
|
||||
import pickle
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePath
|
||||
from pydoc import safeimport
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -37,3 +40,16 @@ def load_from_shelve(file_path, key):
|
||||
def check_path(file_path):
|
||||
assert isinstance(file_path, Path)
|
||||
assert str(file_path).endswith('.pik')
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False):
|
||||
"""Locate an object by name or dotted path, importing as necessary."""
|
||||
models_location = Path(models_location)
|
||||
module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
for module_path in module_paths:
|
||||
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
except AttributeError:
|
||||
continue
|
||||
return model_class
|
||||
|
Loading…
x
Reference in New Issue
Block a user