from state dict to direct parameter update

Steffen Illium
2022-02-22 09:54:54 +01:00
parent 2a710b40d7
commit 78e9c4d520
4 changed files with 34 additions and 15 deletions

View File

@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 200
+EPOCH = 100
 VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -288,7 +288,7 @@ if __name__ == '__main__':
     batch_train_beta = 1
     weight_hidden_size = 3
     residual_skip = True
-    n_seeds = 2
+    n_seeds = 5
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)

View File

@@ -3,7 +3,6 @@ import copy
 import random
 from typing import Union
-import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
@@ -15,6 +14,7 @@ from tqdm import tqdm
 def prng():
     return random.random()
 
 class FixTypes:
     divergent = 'divergent'
@@ -61,14 +61,12 @@ class Net(nn.Module):
     def apply_weights(self, new_weights: Tensor):
         """ Changing the weights of a network to new given values. """
-        keys = self.state_dict().keys()
-        shapes = [x.shape for x in self.state_dict().values()]
-        numels = np.cumsum([0, *[x.numel() for x in self.state_dict().values()]])
-        new_state_dict = {key: new_weights[start: end].view(
-            shape) for key, shape, start, end in zip(keys, shapes, numels, numels[1:])
-        }
-        # noinspection PyTypeChecker
-        self.load_state_dict(new_state_dict)
+        with torch.no_grad():
+            i = 0
+            for parameters in self.parameters():
+                size = parameters.numel()
+                parameters[:] = new_weights[i:i+size].view(parameters.shape)[:]
+                i += size
         return self
 
     def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
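The heart of this commit is the change to apply_weights above: instead of slicing the flat weight vector into a fresh state dict and calling load_state_dict, the values are now written straight into the live parameter tensors under torch.no_grad(), so no second copy of the weights is built and autograd never records the writes. A minimal standalone sketch of the same pattern on a plain nn.Module (the helper name apply_flat_weights and the check via torch.nn.utils are illustrative, not taken from the repo):

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


def apply_flat_weights(module: nn.Module, flat: torch.Tensor) -> nn.Module:
    # Copy a flat weight vector directly into the module's parameters.
    # Under torch.no_grad() the in-place writes are not recorded by autograd.
    with torch.no_grad():
        offset = 0
        for p in module.parameters():
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n
    return module


if __name__ == '__main__':
    net = nn.Linear(4, 2)
    flat = torch.randn(sum(p.numel() for p in net.parameters()))
    apply_flat_weights(net, flat)
    # parameters_to_vector flattens in the same iteration order, so the
    # round trip reproduces the input vector exactly.
    assert torch.equal(parameters_to_vector(net.parameters()), flat)

One behavioural difference worth noting: parameters() excludes registered buffers, while the old state_dict() route included them, so the two versions only coincide for modules whose state consists of parameters alone.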
@@ -164,7 +162,6 @@ class Net(nn.Module):
         weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
         return weight_matrix
 
     def self_train(self,
                    training_steps: int,
                    log_step_size: int = 0,
@@ -478,7 +475,7 @@ class MetaNet(nn.Module):
             for cell in layer.meta_cell_list:
                 # Individual replacement on cell lvl
                 for weight in cell.meta_weight_list:
-                    weight.apply_weights(next(particle_weights_list))
+                    weight.apply_weights(next(particle_weights_list).detach())
         return self
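A plausible reason for the added .detach() in this hunk: the particle weight vectors handed to apply_weights may still be attached to an autograd graph from earlier computation, and detaching them before the in-place copy makes explicit that no backward references survive the transfer. A small standalone illustration of what .detach() guarantees (not repo code):

import torch

src = torch.randn(3, requires_grad=True) * 2.0  # non-leaf tensor with a grad_fn
det = src.detach()                              # same storage, cut off from autograd
assert src.grad_fn is not None
assert det.grad_fn is None and det.requires_grad is False
assert det.data_ptr() == src.data_ptr()         # detach() does not copy the data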

View File

@@ -0,0 +1,22 @@
+import torch
+from network import MetaNet
+from sparse_net import SparseNetwork
+
+if __name__ == '__main__':
+    dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
+                            weight_hidden_size=3, )
+    sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
+                                   weight_hidden_size=3,)
+    particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
+    # Transfer weights
+    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+    # Transfer weights
+    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+    new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
+    print(f' Particles are same: {all([(x==y).all() for x,y in zip(particles, new_particles) ])}')
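The new file reads as a round-trip sanity check for the direct parameter update: it flattens every particle of the dense MetaNet, pushes the particles into the sparse network and back, and compares the flattened weights element-wise; if the copy is lossless the printed flag is True. torch.equal(x, y) would express the same exact per-particle comparison in a single call.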

View File

@@ -55,8 +55,8 @@ if __name__ == '__main__':
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()
-    model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
+    model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
+    model = torch.load(model_path, map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(weights, d_test)
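A side note on the new checkpoint path, offered as a suggestion rather than part of the commit: the raw string uses Windows separators, so on a POSIX system the backslashes stay inside a single file name instead of splitting into directories. Building the path from its components keeps it portable:

from pathlib import Path

# Portable construction of the same checkpoint path; pathlib inserts the
# separator appropriate for the running OS.
model_path = Path('experiments', 'output', 'mn_st_40_6_res_Tsk_0.85',
                  'trained_model_ckpt_e40.tp')
print(model_path)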