New Self Replication

This commit is contained in:
Steffen Illium
2022-02-15 14:25:05 +01:00
parent 62e640e1f0
commit a4d1ee86dd

View File

@ -387,30 +387,30 @@ class MetaNet(nn.Module):
self.weight_hidden_size = weight_hidden_size
self.weight_output_size = weight_output_size
self._meta_layer_list = nn.ModuleList()
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
interface=self.interface,
width=self.width, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size)
)
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
interface=self.width, width=self.width, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size,
) for layer_idx in range(self.depth - 2)]
)
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
interface=self.width, width=self.out, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size,
)
)
self._meta_layer_first = MetaLayer(name=f'L{0}',
interface=self.interface,
width=self.width, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size)
self._meta_layer_list = nn.ModuleList([MetaLayer(name=f'L{layer_idx + 1}',
interface=self.width, width=self.width, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size,
) for layer_idx in range(self.depth - 2)]
)
self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
interface=self.width, width=self.out, residual_skip=residual_skip,
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size,
)
self.dropout_layer = nn.Dropout(p=self.dropout)
self._all_layers_with_particles = [self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last]
def replace_with_zero(self, ident_key):
replaced_particles = 0
for particle in self.particles:
@ -423,16 +423,25 @@ class MetaNet(nn.Module):
return self
def forward(self, x):
tensor = x
for meta_layer in self._meta_layer_list:
if self.dropout_layer:
x = self.dropout_layer(x)
tensor = self._meta_layer_first(x)
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
if self.dropout:
tensor = self.dropout_layer(tensor)
if idx % 2 == 1:
x = tensor.clone()
tensor = meta_layer(tensor)
if idx % 2 == 0:
tensor = tensor + x
if self.dropout_layer:
x = self.dropout_layer(x)
tensor = self._meta_layer_last(x)
return tensor
@property
def particles(self):
return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
return (cell for metalayer in self._all_layers_with_particles for cell in metalayer.particles)
def combined_self_train(self):
losses = []
@ -473,9 +482,9 @@ class MetaNetCompareBaseline(nn.Module):
if __name__ == '__main__':
metanet = MetaNet(interface=3, depth=5, width=3, out=1)
metanet = MetaNet(interface=3, depth=5, width=3, out=1, dropout=0.1, residual_skip=True)
next(metanet.particles).input_weight_matrix()
metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)]))
a = metanet.particles
print('Test')
print('Test')