From 78e9c4d5205c973e2ebd5f28c40ed4a3884b0404 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 22 Feb 2022 09:54:54 +0100 Subject: [PATCH] from state dict to direct parameter update --- experiments/meta_task_exp.py | 4 ++-- network.py | 19 ++++++++----------- sanity_check_particle_weight_swap.py | 22 ++++++++++++++++++++++ sanity_check_weights.py | 4 ++-- 4 files changed, 34 insertions(+), 15 deletions(-) create mode 100644 sanity_check_particle_weight_swap.py diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py index 25607d4..60852d1 100644 --- a/experiments/meta_task_exp.py +++ b/experiments/meta_task_exp.py @@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints WORKER = 10 if not debug else 2 debug = False BATCHSIZE = 500 if not debug else 50 -EPOCH = 200 +EPOCH = 100 VALIDATION_FRQ = 3 if not debug else 1 SELF_TRAIN_FRQ = 1 if not debug else 1 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -288,7 +288,7 @@ if __name__ == '__main__': batch_train_beta = 1 weight_hidden_size = 3 residual_skip = True - n_seeds = 2 + n_seeds = 5 data_path = Path('data') data_path.mkdir(exist_ok=True, parents=True) diff --git a/network.py b/network.py index 3c8ecd8..053194c 100644 --- a/network.py +++ b/network.py @@ -3,7 +3,6 @@ import copy import random from typing import Union -import numpy as np import pandas as pd import torch import torch.nn as nn @@ -15,6 +14,7 @@ from tqdm import tqdm def prng(): return random.random() + class FixTypes: divergent = 'divergent' @@ -61,14 +61,12 @@ class Net(nn.Module): def apply_weights(self, new_weights: Tensor): """ Changing the weights of a network to new given values. """ - keys = self.state_dict().keys() - shapes = [x.shape for x in self.state_dict().values()] - numels = np.cumsum([0, *[x.numel() for x in self.state_dict().values()]]) - new_state_dict = {key: new_weights[start: end].view( - shape) for key, shape, start, end in zip(keys, shapes, numels, numels[1:]) - } - # noinspection PyTypeChecker - self.load_state_dict(new_state_dict) + with torch.no_grad(): + i = 0 + for parameters in self.parameters(): + size = parameters.numel() + parameters[:] = new_weights[i:i+size].view(parameters.shape)[:] + i += size return self def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None: @@ -164,7 +162,6 @@ class Net(nn.Module): weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()]) return weight_matrix - def self_train(self, training_steps: int, log_step_size: int = 0, @@ -478,7 +475,7 @@ class MetaNet(nn.Module): for cell in layer.meta_cell_list: # Individual replacement on cell lvl for weight in cell.meta_weight_list: - weight.apply_weights(next(particle_weights_list)) + weight.apply_weights(next(particle_weights_list).detach()) return self diff --git a/sanity_check_particle_weight_swap.py b/sanity_check_particle_weight_swap.py new file mode 100644 index 0000000..49a432c --- /dev/null +++ b/sanity_check_particle_weight_swap.py @@ -0,0 +1,22 @@ +import torch + +from network import MetaNet +from sparse_net import SparseNetwork + + +if __name__ == '__main__': + dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True, + weight_hidden_size=3, ) + sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True, + weight_hidden_size=3,) + + particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles] + + # Transfer weights + sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles) + + # Transfer weights + dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights) + new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles] + + print(f' Particles are same: {all([(x==y).all() for x,y in zip(particles, new_particles) ])}') diff --git a/sanity_check_weights.py b/sanity_check_weights.py index 9610f1b..32f789b 100644 --- a/sanity_check_weights.py +++ b/sanity_check_weights.py @@ -55,8 +55,8 @@ if __name__ == '__main__': mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False) d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER) loss_fn = nn.CrossEntropyLoss() - - model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval() + model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp') + model = torch.load(model_path, map_location=DEVICE).eval() weights = extract_weights_from_model(model) test_weights_as_model(weights, d_test)