from state dict to direct parameter update

Steffen Illium 2022-02-22 09:54:54 +01:00
parent 2a710b40d7
commit 78e9c4d520
4 changed files with 34 additions and 15 deletions


@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 200
+EPOCH = 100
 VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -288,7 +288,7 @@ if __name__ == '__main__':
     batch_train_beta = 1
     weight_hidden_size = 3
     residual_skip = True
-    n_seeds = 2
+    n_seeds = 5
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)


@@ -3,7 +3,6 @@ import copy
 import random
 from typing import Union
-import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
@@ -15,6 +14,7 @@ from tqdm import tqdm
 def prng():
     return random.random()

 class FixTypes:
     divergent = 'divergent'
@@ -61,14 +61,12 @@ class Net(nn.Module):
     def apply_weights(self, new_weights: Tensor):
         """ Changing the weights of a network to new given values. """
-        keys = self.state_dict().keys()
-        shapes = [x.shape for x in self.state_dict().values()]
-        numels = np.cumsum([0, *[x.numel() for x in self.state_dict().values()]])
-        new_state_dict = {key: new_weights[start: end].view(
-            shape) for key, shape, start, end in zip(keys, shapes, numels, numels[1:])
-        }
-        # noinspection PyTypeChecker
-        self.load_state_dict(new_state_dict)
+        with torch.no_grad():
+            i = 0
+            for parameters in self.parameters():
+                size = parameters.numel()
+                parameters[:] = new_weights[i:i+size].view(parameters.shape)[:]
+                i += size
         return self

     def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
@@ -164,7 +162,6 @@ class Net(nn.Module):
         weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
         return weight_matrix

     def self_train(self,
                    training_steps: int,
                    log_step_size: int = 0,
@@ -478,7 +475,7 @@ class MetaNet(nn.Module):
             for cell in layer.meta_cell_list:
                 # Individual replacement on cell lvl
                 for weight in cell.meta_weight_list:
-                    weight.apply_weights(next(particle_weights_list))
+                    weight.apply_weights(next(particle_weights_list).detach())
         return self
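
Note on the change above: instead of assembling a fresh state dict and pushing it through load_state_dict on every call, apply_weights now writes the flat weight vector directly into each parameter tensor under torch.no_grad(), skipping the per-call dict construction and key matching. The added .detach() at the MetaNet call site keeps the incoming particle weights from carrying autograd history into the copy. A minimal self-contained sketch of the same pattern; apply_flat_weights and its argument names are illustrative, not identifiers from this repository:

    import torch
    import torch.nn as nn

    def apply_flat_weights(module: nn.Module, flat: torch.Tensor) -> nn.Module:
        # Write a flat weight vector into the module's parameters in place,
        # mirroring the direct-update pattern introduced in this commit.
        with torch.no_grad():  # in-place writes must stay out of the autograd graph
            i = 0
            for p in module.parameters():
                n = p.numel()
                p.copy_(flat[i:i + n].view_as(p))  # slice, reshape, copy in place
                i += n
        assert i == flat.numel(), 'weight vector length must match parameter count'
        return module

    if __name__ == '__main__':
        net = nn.Linear(4, 2)
        new_w = torch.randn(sum(p.numel() for p in net.parameters()))
        apply_flat_weights(net, new_w)
        # Round trip: flattening the parameters recovers the input vector exactly.
        flat_again = torch.cat([p.view(-1) for p in net.parameters()])
        assert torch.equal(flat_again, new_w)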


@@ -0,0 +1,22 @@
+import torch
+
+from network import MetaNet
+from sparse_net import SparseNetwork
+
+
+if __name__ == '__main__':
+    dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
+                            weight_hidden_size=3, )
+    sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
+                                   weight_hidden_size=3,)
+
+    particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
+
+    # Transfer weights dense -> sparse
+    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+    # Transfer weights back sparse -> dense
+    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+
+    new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
+    print(f' Particles are same: {all([(x == y).all() for x, y in zip(particles, new_particles)])}')
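
The new script is a round-trip test: it flattens every particle's parameters, moves the weights dense to sparse and back, and compares. Exact equality via (x == y).all() is the right check here, because both transfer directions only copy values; no arithmetic is applied that could introduce rounding error. The comparison idiom in isolation, with plain nn.Linear modules standing in for the repository's particle networks:

    import torch
    import torch.nn as nn

    def flatten_params(module: nn.Module) -> torch.Tensor:
        # Same flattening idiom as in the script: one long 1-D weight vector.
        return torch.cat([p.view(-1) for p in module.parameters()])

    source, target = nn.Linear(3, 3), nn.Linear(3, 3)
    target.load_state_dict(source.state_dict())  # pure copy, bit-for-bit
    print('Weights identical:', torch.equal(flatten_params(source), flatten_params(target)))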


@@ -55,8 +55,8 @@ if __name__ == '__main__':
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()
-    model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
+    model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
+    model = torch.load(model_path, map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(weights, d_test)
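
Two details in the new loading code are worth spelling out. torch.load here unpickles a complete model object, so the defining classes (MetaNet and its parts) must be importable when the checkpoint is read, and map_location remaps GPU-saved tensors onto whatever device is available. Also, the checkpoint path is written as a Windows raw string, which a POSIX filesystem would treat as a single file name containing backslashes. A platform-neutral construction of the same path, as a sketch that reuses the file names from the diff:

    import torch
    from pathlib import Path

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Path joins with / work on Windows and POSIX alike.
    model_path = Path('experiments') / 'output' / 'mn_st_40_6_res_Tsk_0.85' / 'trained_model_ckpt_e40.tp'
    model = torch.load(model_path, map_location=DEVICE).eval()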