Model Loading by string. Within Debugging

This commit is contained in:
Si11ium
2020-08-15 12:42:57 +02:00
parent a4b6c698c3
commit 6bc9447ce1
5 changed files with 108 additions and 58 deletions

View File

@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
def size(self):
return self.shape
@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
super(F_x, self).__init__()
self.in_shape = in_shape
def forward(self, x):
@staticmethod
def forward(x):
return x
@ -174,26 +175,22 @@ class WeightInit:
m.bias.data.fill_(0.01)
class FilterLayer(nn.Module):
class Filter(nn.Module):
def __init__(self):
super(FilterLayer, self).__init__()
def __init__(self, in_shape, pos, dim=-1):
super(Filter, self).__init__()
def forward(self, x):
self.in_shape = in_shape
self.pos = pos
self.dim = dim
raise SystemError('Do not use this Module - broken.')
@staticmethod
def forward(x):
tensor = x[:, -1]
return tensor
class MergingLayer(nn.Module):
def __init__(self):
super(MergingLayer, self).__init__()
def forward(self, x):
# ToDo: Which ones to combine?
return
class FlipTensor(nn.Module):
def __init__(self, dim=-2):
super(FlipTensor, self).__init__()
@ -223,43 +220,53 @@ class AutoPadToShape(object):
return f'AutoPadTransform({self.shape})'
class HorizontalSplitter(nn.Module):
def __init__(self, in_shape, n):
super(HorizontalSplitter, self).__init__()
assert len(in_shape) == 3
self.n = n
self.in_shape = in_shape
self.channel, self.height, self.width = self.in_shape
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
self.shape = (self.channel, self.new_height, self.width)
self.autopad = AutoPadToShape(self.shape)
def forward(self, x):
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_height
end = (block_idx + 1) * self.new_height
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block)
return n_blocks
class HorizontalMerger(nn.Module):
class Splitter(nn.Module):
@property
def shape(self):
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
return merged_shape
return tuple([self._out_shape] * self.n)
@property
def out_shape(self):
return self._out_shape
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
def __init__(self, in_shape, n):
super(HorizontalMerger, self).__init__()
assert len(in_shape) == 3
self.n = n
self.dim = dim
self.in_shape = in_shape
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
x = x.transpose(0, self.dim)
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_dim_size
end = (block_idx + 1) * self.new_dim_size
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block.transpose(0, self.dim))
return n_blocks
class Merger(nn.Module):
@property
def shape(self):
y = self.forward([torch.randn(self.in_shape)])
return y.shape
def __init__(self, in_shape, n, dim=-1):
super(Merger, self).__init__()
self.n = n
self.dim = dim
self.in_shape = in_shape
def forward(self, x):
return torch.cat(x, dim=-2)
return torch.cat(x, dim=self.dim)