Switch apply_weights from state dict loading to direct parameter update
This commit is contained in:
@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
|
|||||||
WORKER = 10 if not debug else 2
|
WORKER = 10 if not debug else 2
|
||||||
debug = False
|
debug = False
|
||||||
BATCHSIZE = 500 if not debug else 50
|
BATCHSIZE = 500 if not debug else 50
|
||||||
EPOCH = 200
|
EPOCH = 100
|
||||||
VALIDATION_FRQ = 3 if not debug else 1
|
VALIDATION_FRQ = 3 if not debug else 1
|
||||||
SELF_TRAIN_FRQ = 1 if not debug else 1
|
SELF_TRAIN_FRQ = 1 if not debug else 1
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
@@ -288,7 +288,7 @@ if __name__ == '__main__':
|
|||||||
batch_train_beta = 1
|
batch_train_beta = 1
|
||||||
weight_hidden_size = 3
|
weight_hidden_size = 3
|
||||||
residual_skip = True
|
residual_skip = True
|
||||||
n_seeds = 2
|
n_seeds = 5
|
||||||
|
|
||||||
data_path = Path('data')
|
data_path = Path('data')
|
||||||
data_path.mkdir(exist_ok=True, parents=True)
|
data_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
|||||||
19
network.py
19
network.py
@@ -3,7 +3,6 @@ import copy
|
|||||||
import random
|
import random
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -15,6 +14,7 @@ from tqdm import tqdm
|
|||||||
def prng():
|
def prng():
|
||||||
return random.random()
|
return random.random()
|
||||||
|
|
||||||
|
|
||||||
class FixTypes:
|
class FixTypes:
|
||||||
|
|
||||||
divergent = 'divergent'
|
divergent = 'divergent'
|
||||||
@@ -61,14 +61,12 @@ class Net(nn.Module):
|
|||||||
|
|
||||||
def apply_weights(self, new_weights: Tensor):
|
def apply_weights(self, new_weights: Tensor):
|
||||||
""" Changing the weights of a network to new given values. """
|
""" Changing the weights of a network to new given values. """
|
||||||
keys = self.state_dict().keys()
|
with torch.no_grad():
|
||||||
shapes = [x.shape for x in self.state_dict().values()]
|
i = 0
|
||||||
numels = np.cumsum([0, *[x.numel() for x in self.state_dict().values()]])
|
for parameters in self.parameters():
|
||||||
new_state_dict = {key: new_weights[start: end].view(
|
size = parameters.numel()
|
||||||
shape) for key, shape, start, end in zip(keys, shapes, numels, numels[1:])
|
parameters[:] = new_weights[i:i+size].view(parameters.shape)[:]
|
||||||
}
|
i += size
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.load_state_dict(new_state_dict)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
|
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
|
||||||
@@ -164,7 +162,6 @@ class Net(nn.Module):
|
|||||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||||
return weight_matrix
|
return weight_matrix
|
||||||
|
|
||||||
|
|
||||||
def self_train(self,
|
def self_train(self,
|
||||||
training_steps: int,
|
training_steps: int,
|
||||||
log_step_size: int = 0,
|
log_step_size: int = 0,
|
||||||
@@ -478,7 +475,7 @@ class MetaNet(nn.Module):
|
|||||||
for cell in layer.meta_cell_list:
|
for cell in layer.meta_cell_list:
|
||||||
# Individual replacement on cell lvl
|
# Individual replacement on cell lvl
|
||||||
for weight in cell.meta_weight_list:
|
for weight in cell.meta_weight_list:
|
||||||
weight.apply_weights(next(particle_weights_list))
|
weight.apply_weights(next(particle_weights_list).detach())
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
22
sanity_check_particle_weight_swap.py
Normal file
22
sanity_check_particle_weight_swap.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from network import MetaNet
|
||||||
|
from sparse_net import SparseNetwork
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
|
||||||
|
weight_hidden_size=3, )
|
||||||
|
sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
|
||||||
|
weight_hidden_size=3,)
|
||||||
|
|
||||||
|
particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
|
||||||
|
|
||||||
|
# Transfer weights: dense particles -> sparse network
|
||||||
|
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
|
||||||
|
|
||||||
|
# Transfer weights back: sparse particle weights -> dense particles
|
||||||
|
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
|
||||||
|
new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
|
||||||
|
|
||||||
|
print(f' Particles are same: {all([(x==y).all() for x,y in zip(particles, new_particles) ])}')
|
||||||
@@ -55,8 +55,8 @@ if __name__ == '__main__':
|
|||||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||||
loss_fn = nn.CrossEntropyLoss()
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
|
model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
|
||||||
model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
|
model = torch.load(model_path, map_location=DEVICE).eval()
|
||||||
weights = extract_weights_from_model(model)
|
weights = extract_weights_from_model(model)
|
||||||
test_weights_as_model(weights, d_test)
|
test_weights_as_model(weights, d_test)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user