Model loading by string. Still being debugged.
modules/util.py
@@ -39,7 +39,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
     # Dataset Loading
     ################################
-    # TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
+    # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
 
     def size(self):
         return self.shape
 
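The TODO above is what the commit title is about: resolving a model class from a plain string. A minimal sketch of one way to do that with importlib (the helper name and call below are hypothetical illustrations, not part of this commit):

import importlib

def load_class_by_string(dotted_path, **params):
    # Split 'library.path.ClassName' into module path and class name,
    # import the module, fetch the class, and instantiate it with params.
    module_path, class_name = dotted_path.rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**params)

# e.g. load_class_by_string('torch.nn.Linear', in_features=4, out_features=2)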
@@ -108,7 +108,8 @@ class F_x(ShapeMixin, nn.Module):
         super(F_x, self).__init__()
         self.in_shape = in_shape
 
-    def forward(self, x):
+    @staticmethod
+    def forward(x):
         return x
 
 
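The change above makes F_x a pure identity: forward no longer touches instance state, so it becomes a @staticmethod. A self-contained sketch of the resulting module (ShapeMixin from the repo is omitted here so the snippet runs standalone):

import torch
import torch.nn as nn

class F_x(nn.Module):
    """Identity module: forward returns its input unchanged."""
    def __init__(self, in_shape):
        super(F_x, self).__init__()
        self.in_shape = in_shape

    @staticmethod
    def forward(x):
        return x

x = torch.randn(2, 3)
assert torch.equal(F_x(in_shape=(2, 3))(x), x)  # pass-through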
@@ -174,26 +175,22 @@ class WeightInit:
             m.bias.data.fill_(0.01)
 
 
-class FilterLayer(nn.Module):
+class Filter(nn.Module):
 
-    def __init__(self):
-        super(FilterLayer, self).__init__()
+    def __init__(self, in_shape, pos, dim=-1):
+        super(Filter, self).__init__()
 
-    def forward(self, x):
+        self.in_shape = in_shape
+        self.pos = pos
+        self.dim = dim
+        raise SystemError('Do not use this Module - broken.')
+
+    @staticmethod
+    def forward(x):
         tensor = x[:, -1]
         return tensor
 
 
-class MergingLayer(nn.Module):
-
-    def __init__(self):
-        super(MergingLayer, self).__init__()
-
-    def forward(self, x):
-        # ToDo: Which ones to combine?
-        return
-
-
 class FlipTensor(nn.Module):
     def __init__(self, dim=-2):
         super(FlipTensor, self).__init__()
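Note that the rewritten Filter raises SystemError from __init__, so constructing it fails fast instead of letting the old x[:, -1] behavior slip into a model; MergingLayer, whose forward returned nothing, is deleted outright. A quick illustration (assuming Filter is imported from modules.util; the argument values are arbitrary):

from modules.util import Filter

try:
    Filter(in_shape=(3, 32, 32), pos=0)
except SystemError as e:
    print(e)  # Do not use this Module - broken.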
@@ -223,43 +220,53 @@ class AutoPadToShape(object):
         return f'AutoPadTransform({self.shape})'
 
 
-class HorizontalSplitter(nn.Module):
-
-    def __init__(self, in_shape, n):
-        super(HorizontalSplitter, self).__init__()
-        assert len(in_shape) == 3
-        self.n = n
-        self.in_shape = in_shape
-
-        self.channel, self.height, self.width = self.in_shape
-        self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
-
-        self.shape = (self.channel, self.new_height, self.width)
-        self.autopad = AutoPadToShape(self.shape)
-
-    def forward(self, x):
-        n_blocks = list()
-        for block_idx in range(self.n):
-            start = block_idx * self.new_height
-            end = (block_idx + 1) * self.new_height
-            block = self.autopad(x[:, :, start:end, :])
-            n_blocks.append(block)
-
-        return n_blocks
-
-
-class HorizontalMerger(nn.Module):
-
-    @property
-    def shape(self):
-        merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
-        return merged_shape
-
-    def __init__(self, in_shape, n):
-        super(HorizontalMerger, self).__init__()
-        assert len(in_shape) == 3
-        self.n = n
-        self.in_shape = in_shape
-
-    def forward(self, x):
-        return torch.cat(x, dim=-2)
+class Splitter(nn.Module):
+
+    @property
+    def shape(self):
+        return tuple([self._out_shape] * self.n)
+
+    @property
+    def out_shape(self):
+        return self._out_shape
+
+    def __init__(self, in_shape, n, dim=-1):
+        super(Splitter, self).__init__()
+        assert len(in_shape) == 3
+        self.n = n
+        self.dim = dim
+        self.in_shape = in_shape
+
+        self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
+        self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
+
+        self.autopad = AutoPadToShape(self._out_shape)
+
+    def forward(self, x: torch.Tensor):
+        x = x.transpose(0, self.dim)
+        n_blocks = list()
+        for block_idx in range(self.n):
+            start = block_idx * self.new_dim_size
+            end = (block_idx + 1) * self.new_dim_size
+            block = self.autopad(x[:, :, start:end, :])
+
+            n_blocks.append(block.transpose(0, self.dim))
+        return n_blocks
+
+
+class Merger(nn.Module):
+
+    @property
+    def shape(self):
+        y = self.forward([torch.randn(self.in_shape)])
+        return y.shape
+
+    def __init__(self, in_shape, n, dim=-1):
+        super(Merger, self).__init__()
+
+        self.n = n
+        self.dim = dim
+        self.in_shape = in_shape
+
+    def forward(self, x):
+        return torch.cat(x, dim=self.dim)
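The new Splitter/Merger pair generalizes the hard-coded horizontal (height-axis) classes to an arbitrary dim. Since the commit is still being debugged, the exact forward/indexing behavior above may not be final; the underlying idea can be shown standalone with plain torch (not the repo classes):

import torch

x = torch.randn(8, 3, 32, 32)             # (batch, channel, height, width)
blocks = torch.chunk(x, chunks=4, dim=-2)  # 4 blocks of height 8
merged = torch.cat(blocks, dim=-2)         # back to (8, 3, 32, 32)
assert torch.equal(merged, x)              # split + merge round-trips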