diff --git a/sparse_net.py b/sparse_net.py index 5273842..004840e 100644 --- a/sparse_net.py +++ b/sparse_net.py @@ -94,13 +94,13 @@ class SparseLayer(nn.Module): def replace_weights_by_particles(self, particles): assert len(particles) == self.nr_nets - - # Particle Weight Update - all_weights = [list(particle.parameters()) for particle in particles] - all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)] - # [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights] - for widx, (weights, key) in enumerate(zip(all_weights, self.state_dict().keys())): - self.state_dict()[key] = weights[:] + with torch.no_grad(): + # Particle Weight Update + all_weights = [list(particle.parameters()) for particle in particles] + all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)] + # [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights] + for weights, parameters in zip(all_weights, self.parameters()): + parameters[:] = weights[:] return self def __call__(self, x):