from state dict to direct parameter update
This commit is contained in:
parent
2a710b40d7
commit
78e9c4d520
@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 500 if not debug else 50
|
||||
EPOCH = 200
|
||||
EPOCH = 100
|
||||
VALIDATION_FRQ = 3 if not debug else 1
|
||||
SELF_TRAIN_FRQ = 1 if not debug else 1
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@ -288,7 +288,7 @@ if __name__ == '__main__':
|
||||
batch_train_beta = 1
|
||||
weight_hidden_size = 3
|
||||
residual_skip = True
|
||||
n_seeds = 2
|
||||
n_seeds = 5
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
19
network.py
19
network.py
@ -3,7 +3,6 @@ import copy
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -15,6 +14,7 @@ from tqdm import tqdm
|
||||
def prng() -> float:
    """Return a pseudo-random float from Python's global ``random`` generator.

    Thin wrapper around ``random.random()``, which yields values in [0.0, 1.0).
    """
    return random.random()
|
||||
|
||||
|
||||
class FixTypes:
|
||||
|
||||
divergent = 'divergent'
|
||||
@ -61,14 +61,12 @@ class Net(nn.Module):
|
||||
|
||||
def apply_weights(self, new_weights: Tensor):
    """ Changing the weights of a network to new given values.

    The parameters are overwritten in place from one packed tensor:
    ``new_weights`` is consumed front to back in ``self.parameters()`` order,
    each parameter taking the next ``numel()`` entries along the first
    dimension, reshaped to the parameter's shape. Surplus trailing entries
    are ignored.

    Values are written directly into the existing parameter tensors under
    ``torch.no_grad()`` (instead of going through ``load_state_dict``), so the
    parameter objects are kept and the write is not recorded by autograd.

    :param new_weights: tensor holding at least as many entries as the module
        has parameter elements.
    :raises ValueError: if ``new_weights`` holds too few entries; checked up
        front so the module is never left partially updated.
    :return: ``self``, so calls can be chained.
    """
    targets = list(self.parameters())
    required = sum(param.numel() for param in targets)
    if new_weights.numel() < required:
        raise ValueError(
            f'apply_weights needs {required} values but got {new_weights.numel()}')

    with torch.no_grad():
        offset = 0
        for param in targets:
            size = param.numel()
            # copy_ (unlike ``param[:] = ...``) also works for 0-dim parameters,
            # and reshape (unlike view) accepts non-contiguous input.
            param.copy_(new_weights[offset:offset + size].reshape(param.shape))
            offset += size
    return self
|
||||
|
||||
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
|
||||
@ -164,7 +162,6 @@ class Net(nn.Module):
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
return weight_matrix
|
||||
|
||||
|
||||
def self_train(self,
|
||||
training_steps: int,
|
||||
log_step_size: int = 0,
|
||||
@ -478,7 +475,7 @@ class MetaNet(nn.Module):
|
||||
for cell in layer.meta_cell_list:
|
||||
# Individual replacement on cell lvl
|
||||
for weight in cell.meta_weight_list:
|
||||
weight.apply_weights(next(particle_weights_list))
|
||||
weight.apply_weights(next(particle_weights_list).detach())
|
||||
return self
|
||||
|
||||
|
||||
|
22
sanity_check_particle_weight_swap.py
Normal file
22
sanity_check_particle_weight_swap.py
Normal file
@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
from network import MetaNet
|
||||
from sparse_net import SparseNetwork
|
||||
|
||||
|
||||
def flatten_particles(particles):
    """Return one flat 1-D weight tensor per particle.

    Each particle's parameters are flattened and concatenated in
    ``parameters()`` order. ``torch.cat`` allocates a new tensor, so every
    entry is a snapshot that later in-place parameter updates do not alter.
    """
    return [torch.cat([param.view(-1) for param in particle.parameters()])
            for particle in particles]


if __name__ == '__main__':
    # Round-trip sanity check: push the dense net's particle weights into the
    # sparse net and back again, then compare against the initial snapshot.
    dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
                            weight_hidden_size=3, )
    sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
                                   weight_hidden_size=3,)

    particles = flatten_particles(dense_metanet.particles)

    # Transfer weights: dense -> sparse
    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)

    # Transfer weights back: sparse -> dense
    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
    new_particles = flatten_particles(dense_metanet.particles)

    # zip() alone would silently drop unmatched particles, so compare the
    # counts first; torch.equal checks both shape and values.
    particles_match = len(particles) == len(new_particles) and all(
        torch.equal(before, after) for before, after in zip(particles, new_particles))
    print(f' Particles are same: {particles_match}')
|
@ -55,8 +55,8 @@ if __name__ == '__main__':
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
|
||||
model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
|
||||
model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
test_weights_as_model(weights, d_test)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user