from state dict to direct parameter update

This commit is contained in:
Steffen Illium 2022-02-22 10:08:27 +01:00
parent bb12176f72
commit f0ad875e79

View File

@ -94,13 +94,13 @@ class SparseLayer(nn.Module):
def replace_weights_by_particles(self, particles):
assert len(particles) == self.nr_nets
with torch.no_grad():
# Particle Weight Update
all_weights = [list(particle.parameters()) for particle in particles]
all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
# [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
for widx, (weights, key) in enumerate(zip(all_weights, self.state_dict().keys())):
self.state_dict()[key] = weights[:]
for weights, parameters in zip(all_weights, self.parameters()):
parameters[:] = weights[:]
return self
def __call__(self, x):