MetaNetworks Debugged
This commit is contained in:
102
network.py
102
network.py
@ -34,17 +34,21 @@ class Net(nn.Module):
|
||||
def are_weights_diverged(network_weights):
    """ Testing if the weights are either NaN or diverged to infinity / -infinity.

    :param network_weights: either an ``nn.Module`` (all tensors yielded by its
        ``parameters()`` are inspected) or an iterable of tensors.
    :return: ``True`` as soon as one tensor contains a NaN or infinite entry,
        ``False`` otherwise (also for an empty collection).
    """
    # `parameters` is a method on nn.Module and has to be called; plain
    # iterables of tensors are consumed as they are.
    if isinstance(network_weights, torch.nn.Module):
        tensors = network_weights.parameters()
    else:
        tensors = network_weights
    # One vectorised check per tensor instead of a Python loop per scalar weight.
    return any(bool(t.isnan().any()) or bool(t.isinf().any()) for t in tensors)
|
||||
|
||||
def apply_weights(self, new_weights: Tensor):
|
||||
""" Changing the weights of a network to new given values. """
|
||||
# TODO: Change this to 'parameters' version
|
||||
i = 0
|
||||
for layer_id, layer_name in enumerate(self.state_dict()):
|
||||
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
|
||||
@ -101,15 +105,17 @@ class Net(nn.Module):
|
||||
# Cell Enumeration
|
||||
torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
|
||||
# Weight Enumeration within the Cells
|
||||
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
|
||||
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
|
||||
*(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size-4))
|
||||
), dim=1)
|
||||
)
|
||||
# Finalize
|
||||
weight_matrix = torch.cat(weight_matrix).float()
|
||||
|
||||
# Normalize all along the 1 dimensions
|
||||
norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
|
||||
weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
|
||||
# Normalize 1,2,3 column of dim 1
|
||||
last_pos_idx = self.input_size - 4
|
||||
norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
|
||||
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
|
||||
|
||||
# computations
|
||||
# create a mask where pos is 0 if it is to be replaced
|
||||
@ -117,7 +123,7 @@ class Net(nn.Module):
|
||||
mask[:, 0] = 0
|
||||
|
||||
self._weight_pos_enc_and_mask = weight_matrix, mask
|
||||
return self._weight_pos_enc_and_mask
|
||||
return tuple(x.clone() for x in self._weight_pos_enc_and_mask)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
@ -125,6 +131,7 @@ class Net(nn.Module):
|
||||
return x
|
||||
|
||||
def normalize(self, value, norm):
|
||||
raise NotImplementedError
|
||||
# FIXME, This is bullshit, the code does not do what the docstring explains
|
||||
# Obsolete now
|
||||
""" Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
|
||||
@ -138,7 +145,7 @@ class Net(nn.Module):
|
||||
""" Calculating the input tensor formed from the weights of the net """
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
pos_enc, mask = self._weight_pos_enc
|
||||
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
|
||||
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
|
||||
return weight_matrix
|
||||
|
||||
def self_train(self,
|
||||
@ -283,33 +290,50 @@ class SecondaryNet(Net):
|
||||
|
||||
|
||||
class MetaCell(nn.Module):
|
||||
def __init__(self, name, interface):
    """ Build one meta cell made of `interface` many small weight networks.

    :param name: identifier of the cell; used as prefix for the names of the
        contained networks (``<name>_W<idx>``).
    :param interface: number of networks created for this cell.
    """
    super().__init__()
    self.name = name
    self.interface = interface
    # Arguments handed to every contained `Net`; their exact meaning is defined
    # by the `Net` constructor (presumably input / hidden / output size).
    self.weight_interface = 5
    self.net_hidden_size = 4
    self.net_ouput_size = 1
    self.meta_weight_list = nn.ModuleList()
    self.meta_weight_list.extend(
        [Net(self.weight_interface, self.net_hidden_size,
             self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
             ) for weight_idx in range(self.interface)]
    )
    # Cache slot for the (embedding, mask) pair; filled on first access of `_bed_mask`.
    self.__bed_mask = None
|
||||
|
||||
@property
def _bed_mask(self):
    """ Lazily built (embedding, mask) pair; every access returns fresh clones.

    The embedding is an all-zero row of width `weight_interface`; the mask is
    one everywhere except in the last column, which is zero.
    """
    cached = self.__bed_mask
    if cached is None:
        # Allocate on the same device as the module's parameters.
        device = next(self.parameters()).device
        zero_embedding = torch.zeros(1, self.weight_interface, device=device)

        # A zero in the mask marks the position that is to be replaced.
        keep_mask = torch.ones_like(zero_embedding)
        keep_mask[:, -1] = 0

        cached = self.__bed_mask = zero_embedding, keep_mask
    # Hand out copies so callers cannot modify the cached tensors.
    return tuple(t.clone() for t in cached)
|
||||
|
||||
def forward(self, x):
    """ Feed every input column through its own weight network and sum the results.

    :param x: input tensor; the indexing below assumes a 2-d layout of
        (batch, number of weight networks) -- TODO confirm against callers.
    :return: tensor with the per-network outputs summed along the last
        dimension (``keepdim=True``).
    """
    embedding, mask = self._bed_mask
    width = embedding.shape[-1]
    expanded_mask = mask.expand(*x.shape, width)
    embedding = embedding.repeat(*x.shape, 1)

    # Column-wise: broadcast every input value along a new last axis, then keep
    # it only where the mask is 0; everywhere else the embedding is used.
    xs = x.unsqueeze(-1).expand(-1, -1, width)
    xs = embedding * expanded_mask + xs * (1 - expanded_mask)
    # TODO: Speed this up! One network call per input column.
    tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])

    return torch.sum(tensor, dim=-1, keepdim=True)
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
@ -317,21 +341,27 @@ class MetaCell(nn.Module):
|
||||
|
||||
|
||||
class MetaLayer(nn.Module):
|
||||
def __init__(self, name, interface=4, width=4, residual_skip=True):
    """ Build a layer consisting of `width` many meta cells.

    :param name: identifier of the layer; used as prefix for the cell names
        (``<name>_C<idx>``).
    :param interface: forwarded unchanged to every `MetaCell`.
    :param width: number of meta cells in this layer.
    :param residual_skip: stored on the instance; `forward` reads it to decide
        whether the layer input is added to the layer output.
    """
    super().__init__()
    self.residual_skip = residual_skip
    self.name = name
    self.interface = interface
    self.width = width

    self.meta_cell_list = nn.ModuleList()
    self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
                                         interface=interface
                                         ) for cell_idx in range(self.width)]
                               )
|
||||
|
||||
def forward(self, x):
    """ Run the input through every meta cell and stack the cell outputs horizontally.

    :param x: input tensor, handed unchanged to each cell.
    :return: `torch.hstack` of all cell outputs; if `residual_skip` is set and
        the stacked result has exactly the shape of `x`, `x` is added to it.
    """
    tensor = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
    # The residual connection is only applied when input and output shapes
    # match; otherwise the layer output is returned as is.
    if self.residual_skip and x.shape == tensor.shape:
        tensor += x
    return tensor
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
@ -349,15 +379,15 @@ class MetaNet(nn.Module):
|
||||
self.depth = depth
|
||||
|
||||
self._meta_layer_list = nn.ModuleList()
|
||||
self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
|
||||
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
|
||||
interface=self.interface,
|
||||
width=self.width)
|
||||
)
|
||||
self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
|
||||
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
|
||||
interface=self.width, width=self.width
|
||||
) for layer_idx in range(self.depth - 2)]
|
||||
)
|
||||
self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
|
||||
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
|
||||
interface=self.width, width=self.out)
|
||||
)
|
||||
|
||||
@ -383,9 +413,9 @@ class MetaNet(nn.Module):
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: build a small meta network and push one batch through it.
    metanet = MetaNet(interface=3, depth=5, width=3, out=1)
    next(metanet.particles).input_weight_matrix()
    # Two rows; column i is filled with the value i, one column per network input.
    metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
    a = metanet.particles
    print('Test')
    print('Test')
|
||||
|
Reference in New Issue
Block a user