reset diverged by xavier init

This commit is contained in:
Steffen Illium
2022-02-26 16:42:37 +01:00
parent a3a587476c
commit 78a919395b
3 changed files with 25 additions and 1 deletions

View File

@ -17,6 +17,11 @@ from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
def xavier_init(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
class SparseLayer(nn.Module):
def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
super(SparseLayer, self).__init__()
@ -41,6 +46,7 @@ class SparseLayer(nn.Module):
self.indices.append(indices)
self.diag_shapes.append(diag_shape)
self.weights.append(weights)
self.apply(xavier_init)
def coo_sparse_layer(self, layer_id):
with torch.no_grad():
@ -91,6 +97,12 @@ class SparseLayer(nn.Module):
particles.apply_weights(weights)
return self._particles
def reset_diverged_particles(self):
for weights in self.weights:
if torch.isinf(weights).any() or torch.isnan(weights).any():
with torch.no_grad():
xavier_init(weights)
@property
def particle_weights(self):
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
@ -246,6 +258,10 @@ class SparseNetwork(nn.Module):
def particle_weights(self):
return (x for y in self.sparselayers for x in y.particle_weights)
def reset_diverged_particles(self):
for layer in self.sparselayers:
layer.reset_diverged_particles()
def to(self, *args, **kwargs):
super(SparseNetwork, self).to(*args, **kwargs)
self.first_layer = self.first_layer.to(*args, **kwargs)
@ -266,7 +282,7 @@ class SparseNetwork(nn.Module):
output = layer(x)
# loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
loss = loss_fn(output, target_data) * 100
loss = loss_fn(output, target_data) * 85
losses.append(loss.detach())
loss.backward()
@ -314,6 +330,7 @@ def test_sparse_net_sef_train():
tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}")
for key, value in counter.items():
df.loc[df.shape[0]] = (epoch, key, value)
net.reset_diverged_particles()
counter = defaultdict(lambda: 0)
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))