MetaNetworks Debugged

Steffen Illium
2022-01-31 10:35:11 +01:00
parent 49c0d8a621
commit 246d825bb4
8 changed files with 169 additions and 109 deletions

@@ -34,17 +34,21 @@ class Net(nn.Module):
def are_weights_diverged(network_weights):
""" Testing if the weights are eiter converging to infinity or -infinity. """
for layer_id, layer in enumerate(network_weights):
for cell_id, cell in enumerate(layer):
for weight_id, weight in enumerate(cell):
if torch.isnan(weight):
return True
if torch.isinf(weight):
return True
return False
# Slow and naive:
# for layer_id, layer in enumerate(network_weights):
# for cell_id, cell in enumerate(layer):
# for weight_id, weight in enumerate(cell):
# if torch.isnan(weight):
# return True
# if torch.isinf(weight):
# return True
# return False
# Fast and modern:
return any(x.isnan().any() or x.isinf().any() for x in network_weights.parameters())
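# Minimal usage sketch (doctest-style). Note the helper receives a module and
# calls `.parameters()` on it, despite the `network_weights` name; the
# constructor arguments mirror how Net is instantiated in MetaCell below:
#   >>> net = Net(5, 4, 1, name='demo')
#   >>> Net.are_weights_diverged(net)  # freshly initialised weights are finite
#   False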
def apply_weights(self, new_weights: Tensor):
""" Changing the weights of a network to new given values. """
# TODO: Change this to 'parameters' version
i = 0
for layer_id, layer_name in enumerate(self.state_dict()):
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
@@ -101,15 +105,17 @@ class Net(nn.Module):
# Cell Enumeration
torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
# Weight Enumeration within the Cells
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
*(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size - 4))
), dim=1)
)
# Finalize
weight_matrix = torch.cat(weight_matrix).float()
# Normalize all columns along dim 1
norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
# Normalize the positional columns (1-3) of dim 1
last_pos_idx = self.input_size - 4
norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
# Create a mask that is 0 at the position to be replaced by the weight value
@@ -117,7 +123,7 @@ class Net(nn.Module):
mask[:, 0] = 0
self._weight_pos_enc_and_mask = weight_matrix, mask
return self._weight_pos_enc_and_mask
return tuple(x.clone() for x in self._weight_pos_enc_and_mask)
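# Returning clones keeps the cached (pos_enc, mask) pair safe from in-place
# mutation by callers; a minimal sketch of the hazard the clone avoids:
#   >>> cached = torch.zeros(2)
#   >>> alias = cached   # handing out the cached tensor itself...
#   >>> alias += 1       # ...lets an in-place op corrupt the cache
#   >>> cached
#   tensor([1., 1.])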
def forward(self, x):
for layer in self.layers:
@@ -125,6 +131,7 @@ class Net(nn.Module):
return x
def normalize(self, value, norm):
raise NotImplementedError
# FIXME: the code does not do what the docstring describes
# Obsolete now
""" Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
@@ -138,7 +145,7 @@ class Net(nn.Module):
""" Calculating the input tensor formed from the weights of the net """
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
pos_enc, mask = self._weight_pos_enc
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
return weight_matrix
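# The blend keeps the positional encoding where mask == 1 and writes the
# broadcast weight value where mask == 0 (column 0, per the mask above);
# a toy example with illustrative values:
#   >>> pos_enc = torch.tensor([[9., 1., 2., 3., 4.]])
#   >>> mask = torch.tensor([[0., 1., 1., 1., 1.]])
#   >>> w = torch.tensor([[0.5]])
#   >>> pos_enc * mask + w.expand(-1, 5) * (1 - mask)
#   tensor([[0.5000, 1.0000, 2.0000, 3.0000, 4.0000]])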
def self_train(self,
@@ -283,33 +290,50 @@ class SecondaryNet(Net):
class MetaCell(nn.Module):
def __init__(self, name, interface, residual_skip=True):
def __init__(self, name, interface):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.weight_interface = 4
self.weight_interface = 5
self.net_hidden_size = 4
self.net_output_size = 1
self.meta_weight_list = nn.ModuleList()
self.meta_weight_list.extend(
[Net(self.weight_interface, self.net_hidden_size,
self.net_output_size, name=f'{self.name}_{weight_idx}'
self.net_output_size, name=f'{self.name}_W{weight_idx}'
) for weight_idx in range(self.interface)]
)
self.__bed_mask = None
@property
def _bed_mask(self):
if self.__bed_mask is None:
d = next(self.parameters()).device
embedding = torch.zeros(1, self.weight_interface, device=d)
# Create a mask that is 0 at the position to be replaced by the data value
mask = torch.ones_like(embedding)
mask[:, -1] = 0
self.__bed_mask = embedding, mask
return tuple(x.clone() for x in self.__bed_mask)
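# Layout of the cached pair, assuming weight_interface == 5:
#   embedding = [[0., 0., 0., 0., 0.]]  # placeholder row, repeated per input in forward
#   mask      = [[1., 1., 1., 1., 0.]]  # last slot is reserved for the data value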
def forward(self, x):
xs = [torch.hstack(
(torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
)
for idx in range(len(self.meta_weight_list))]
tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
embedding, mask = self._bed_mask
expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
embedding = embedding.repeat(*x.shape, 1)
if self.residual_skip:
tensor += x
# Row-wise
# xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
# Column-wise
xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])
xs = embedding * expanded_mask + xs * (1 - expanded_mask)
# TODO: Speed this up!
tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])
result = torch.sum(tensor, dim=-1, keepdim=True)
return result
tensor = torch.sum(tensor, dim=-1, keepdim=True)
return tensor
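# Shape walk-through for the rewritten forward, assuming x is (B, I) with
# B = batch size, I = self.interface, W = self.weight_interface == 5:
#   xs after expand: (B, I, W)  -- the scalar input copied into every slot
#   xs after blend:  (B, I, W)  -- [enc, enc, enc, enc, x_i] per weight net
#   each meta_weight maps (B, W) -> (B, 1); hstack -> (B, I); sum -> (B, 1)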
@property
def particles(self):
@@ -317,21 +341,27 @@ class MetaCell(nn.Module):
class MetaLayer(nn.Module):
def __init__(self, name, interface=4, width=4):
def __init__(self, name, interface=4, width=4, residual_skip=True):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.width = width
self.meta_cell_list = nn.ModuleList()
self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
interface=interface
) for cell_idx in range(self.width)]
)
def forward(self, x):
result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
return result
cell_results = []
for metacell in self.meta_cell_list:
cell_results.append(metacell(x))
tensor = torch.hstack(cell_results)
if self.residual_skip and x.shape == tensor.shape:
tensor += x
return tensor
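# The shape guard makes the residual a silent no-op on layers that change
# width (e.g. the final width -> out layer); only same-shape layers add x:
#   x: (B, 4) -> tensor: (B, 4)  => tensor += x
#   x: (B, 4) -> tensor: (B, 1)  => skip connection omitted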
@property
def particles(self):
@@ -349,15 +379,15 @@ class MetaNet(nn.Module):
self.depth = depth
self._meta_layer_list = nn.ModuleList()
self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
interface=self.interface,
width=self.width)
)
self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
interface=self.width, width=self.width
) for layer_idx in range(self.depth - 2)]
)
self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
interface=self.width, width=self.out)
)
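# Resulting stack, e.g. for MetaNet(interface=3, depth=5, width=3, out=1)
# as constructed in __main__ below:
#   L0:       3 -> 3  (interface -> width)
#   L1 .. L3: 3 -> 3  (depth - 2 hidden layers)
#   L4:       3 -> 1  (width -> out)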
@@ -383,9 +413,9 @@ class MetaNet(nn.Module):
if __name__ == '__main__':
metanet = MetaNet(interface=2, depth=3, width=2, out=1)
metanet = MetaNet(interface=3, depth=5, width=3, out=1)
next(metanet.particles).input_weight_matrix()
metanet(torch.ones((5, 2)))
metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
a = metanet.particles
print('Test')