self-replicating-neural-net.../sanity_check_particle_weight_swap.py
2022-02-22 09:54:54 +01:00

23 lines
907 B
Python

import torch
from network import MetaNet
from sparse_net import SparseNetwork
def flatten_particles(particles):
    """Return one flat 1-D tensor per particle.

    Each particle must expose ``parameters()`` (as ``torch.nn.Module`` does).
    ``torch.cat`` allocates a new tensor, so the result is a snapshot of the
    current values rather than a view onto the live parameters.
    """
    return [
        # reshape(-1) matches view(-1) for contiguous tensors and also
        # accepts non-contiguous ones.
        torch.cat([param.reshape(-1) for param in particle.parameters()])
        for particle in particles
    ]


def particles_match(before, after):
    """Return True iff both sequences hold pairwise identical tensors.

    The lengths are compared first because ``zip`` silently stops at the
    shorter sequence, which would hide missing particles. ``torch.equal``
    returns False for tensors of different size instead of raising or
    broadcasting like ``(a == b).all()``.
    """
    if len(before) != len(after):
        return False
    return all(torch.equal(old, new) for old, new in zip(before, after))


if __name__ == '__main__':
    # Build both networks from the same arguments so they cannot drift apart.
    net_kwargs = dict(depth=5, width=6, out=10, residual_skip=True, weight_hidden_size=3)
    dense_metanet = MetaNet(30, **net_kwargs)
    sparse_metanet = SparseNetwork(30, **net_kwargs)

    # Snapshot of the dense particles before any weights are moved.
    particles = flatten_particles(dense_metanet.particles)

    # Transfer weights: dense -> sparse.
    # NOTE(review): the return value is re-bound, so this method is assumed
    # to return the network itself -- confirm in sparse_net.py.
    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)

    # Transfer weights back: sparse -> dense.
    # NOTE(review): same assumption about the return value -- confirm in network.py.
    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)

    # After the round trip the dense particles should be unchanged.
    new_particles = flatten_particles(dense_metanet.particles)
    print(f' Particles are same: {particles_match(particles, new_particles)}')