New Self Replication
This commit is contained in:
63
network.py
63
network.py
@@ -387,30 +387,30 @@ class MetaNet(nn.Module):
|
|||||||
self.weight_hidden_size = weight_hidden_size
|
self.weight_hidden_size = weight_hidden_size
|
||||||
self.weight_output_size = weight_output_size
|
self.weight_output_size = weight_output_size
|
||||||
|
|
||||||
self._meta_layer_list = nn.ModuleList()
|
self._meta_layer_first = MetaLayer(name=f'L{0}',
|
||||||
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
|
interface=self.interface,
|
||||||
interface=self.interface,
|
width=self.width, residual_skip=residual_skip,
|
||||||
width=self.width, residual_skip=residual_skip,
|
weight_interface=weight_interface,
|
||||||
weight_interface=weight_interface,
|
weight_hidden_size=weight_hidden_size,
|
||||||
weight_hidden_size=weight_hidden_size,
|
weight_output_size=weight_output_size)
|
||||||
weight_output_size=weight_output_size)
|
|
||||||
)
|
self._meta_layer_list = nn.ModuleList([MetaLayer(name=f'L{layer_idx + 1}',
|
||||||
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
|
interface=self.width, width=self.width, residual_skip=residual_skip,
|
||||||
interface=self.width, width=self.width, residual_skip=residual_skip,
|
weight_interface=weight_interface,
|
||||||
weight_interface=weight_interface,
|
weight_hidden_size=weight_hidden_size,
|
||||||
weight_hidden_size=weight_hidden_size,
|
weight_output_size=weight_output_size,
|
||||||
weight_output_size=weight_output_size,
|
) for layer_idx in range(self.depth - 2)]
|
||||||
) for layer_idx in range(self.depth - 2)]
|
)
|
||||||
)
|
self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
|
||||||
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
|
interface=self.width, width=self.out, residual_skip=residual_skip,
|
||||||
interface=self.width, width=self.out, residual_skip=residual_skip,
|
weight_interface=weight_interface,
|
||||||
weight_interface=weight_interface,
|
weight_hidden_size=weight_hidden_size,
|
||||||
weight_hidden_size=weight_hidden_size,
|
weight_output_size=weight_output_size,
|
||||||
weight_output_size=weight_output_size,
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
self.dropout_layer = nn.Dropout(p=self.dropout)
|
self.dropout_layer = nn.Dropout(p=self.dropout)
|
||||||
|
|
||||||
|
self._all_layers_with_particles = [self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last]
|
||||||
|
|
||||||
def replace_with_zero(self, ident_key):
|
def replace_with_zero(self, ident_key):
|
||||||
replaced_particles = 0
|
replaced_particles = 0
|
||||||
for particle in self.particles:
|
for particle in self.particles:
|
||||||
@@ -423,16 +423,25 @@ class MetaNet(nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = x
|
if self.dropout_layer:
|
||||||
for meta_layer in self._meta_layer_list:
|
x = self.dropout_layer(x)
|
||||||
|
tensor = self._meta_layer_first(x)
|
||||||
|
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||||
if self.dropout:
|
if self.dropout:
|
||||||
tensor = self.dropout_layer(tensor)
|
tensor = self.dropout_layer(tensor)
|
||||||
|
if idx % 2 == 1:
|
||||||
|
x = tensor.clone()
|
||||||
tensor = meta_layer(tensor)
|
tensor = meta_layer(tensor)
|
||||||
|
if idx % 2 == 0:
|
||||||
|
tensor = tensor + x
|
||||||
|
if self.dropout_layer:
|
||||||
|
x = self.dropout_layer(x)
|
||||||
|
tensor = self._meta_layer_last(x)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def particles(self):
|
def particles(self):
|
||||||
return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
|
return (cell for metalayer in self._all_layers_with_particles for cell in metalayer.particles)
|
||||||
|
|
||||||
def combined_self_train(self):
|
def combined_self_train(self):
|
||||||
losses = []
|
losses = []
|
||||||
@@ -473,9 +482,9 @@ class MetaNetCompareBaseline(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
metanet = MetaNet(interface=3, depth=5, width=3, out=1)
|
metanet = MetaNet(interface=3, depth=5, width=3, out=1, dropout=0.1, residual_skip=True)
|
||||||
next(metanet.particles).input_weight_matrix()
|
next(metanet.particles).input_weight_matrix()
|
||||||
metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
|
metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)]))
|
||||||
a = metanet.particles
|
a = metanet.particles
|
||||||
print('Test')
|
print('Test')
|
||||||
print('Test')
|
print('Test')
|
||||||
|
|||||||
Reference in New Issue
Block a user