from state dict to direct parameter update
This commit is contained in:
parent
bb12176f72
commit
f0ad875e79
@ -94,13 +94,13 @@ class SparseLayer(nn.Module):
|
||||
|
||||
def replace_weights_by_particles(self, particles):
|
||||
assert len(particles) == self.nr_nets
|
||||
|
||||
# Particle Weight Update
|
||||
all_weights = [list(particle.parameters()) for particle in particles]
|
||||
all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
|
||||
# [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
for widx, (weights, key) in enumerate(zip(all_weights, self.state_dict().keys())):
|
||||
self.state_dict()[key] = weights[:]
|
||||
with torch.no_grad():
|
||||
# Particle Weight Update
|
||||
all_weights = [list(particle.parameters()) for particle in particles]
|
||||
all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
|
||||
# [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
for weights, parameters in zip(all_weights, self.parameters()):
|
||||
parameters[:] = weights[:]
|
||||
return self
|
||||
|
||||
def __call__(self, x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user