reset diverged by xavier init
This commit is contained in:
@ -17,6 +17,11 @@ from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
|
||||
|
||||
def xavier_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight.data)
|
||||
|
||||
|
||||
class SparseLayer(nn.Module):
|
||||
def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
|
||||
super(SparseLayer, self).__init__()
|
||||
@ -41,6 +46,7 @@ class SparseLayer(nn.Module):
|
||||
self.indices.append(indices)
|
||||
self.diag_shapes.append(diag_shape)
|
||||
self.weights.append(weights)
|
||||
self.apply(xavier_init)
|
||||
|
||||
def coo_sparse_layer(self, layer_id):
|
||||
with torch.no_grad():
|
||||
@ -91,6 +97,12 @@ class SparseLayer(nn.Module):
|
||||
particles.apply_weights(weights)
|
||||
return self._particles
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for weights in self.weights:
|
||||
if torch.isinf(weights).any() or torch.isnan(weights).any():
|
||||
with torch.no_grad():
|
||||
xavier_init(weights)
|
||||
|
||||
@property
|
||||
def particle_weights(self):
|
||||
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
@ -246,6 +258,10 @@ class SparseNetwork(nn.Module):
|
||||
def particle_weights(self):
|
||||
return (x for y in self.sparselayers for x in y.particle_weights)
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for layer in self.sparselayers:
|
||||
layer.reset_diverged_particles()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super(SparseNetwork, self).to(*args, **kwargs)
|
||||
self.first_layer = self.first_layer.to(*args, **kwargs)
|
||||
@ -266,7 +282,7 @@ class SparseNetwork(nn.Module):
|
||||
output = layer(x)
|
||||
# loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
|
||||
|
||||
loss = loss_fn(output, target_data) * 100
|
||||
loss = loss_fn(output, target_data) * 85
|
||||
|
||||
losses.append(loss.detach())
|
||||
loss.backward()
|
||||
@ -314,6 +330,7 @@ def test_sparse_net_sef_train():
|
||||
tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
df.loc[df.shape[0]] = (epoch, key, value)
|
||||
net.reset_diverged_particles()
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
|
||||
|
Reference in New Issue
Block a user