reset diverged by xavier init

Steffen Illium
2022-02-26 16:42:37 +01:00
parent a3a587476c
commit 78a919395b
3 changed files with 25 additions and 1 deletion

View File

@@ -393,6 +393,8 @@ if __name__ == '__main__':
         if use_sparse_network:
             dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+            dense_metanet.reset_diverged_particles()
         # Task Train
         if not init_st:
             # Zero your gradients for every batch!
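The mechanism behind reset_diverged_particles(), stripped of the repository's classes, is: after the dense net has been synced from the sparse particle weights, any particle whose weights have blown up is re-drawn from a Xavier-uniform distribution. A minimal self-contained sketch of that detect-and-reset step (the helper name reset_if_diverged and the plain nn.Linear stand-in are illustrative, not part of this commit):

    import torch
    import torch.nn as nn

    def reset_if_diverged(module: nn.Module) -> bool:
        """Re-draw a module's Linear weights if any parameter contains NaN/Inf."""
        diverged = any(bool(torch.isnan(p).any() or torch.isinf(p).any())
                       for p in module.parameters())
        if diverged:
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
        return diverged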

View File

@@ -487,6 +487,11 @@ class MetaNet(nn.Module):
     def all_layers(self):
         return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
 
+    def reset_diverged_particles(self):
+        for particle in self.particles:
+            if particle.is_fixpoint == FixTypes.divergent:
+                particle.apply(xavier_init)
+
 class MetaNetCompareBaseline(nn.Module):
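particle.apply(xavier_init) relies on nn.Module.apply, which calls the given function on the module and, recursively, on every submodule; together with the isinstance(m, nn.Linear) guard inside xavier_init (added further down in this commit), only the linear layers of a divergent particle are re-initialized, leaving everything else untouched. A standalone demonstration of that mechanism (the two-layer net is illustrative):

    import torch.nn as nn

    def xavier_init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight.data)

    net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
    net.apply(xavier_init)  # visits the Sequential, both Linears, and the Tanh; only the Linears match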

View File

@@ -17,6 +17,11 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 
+
+def xavier_init(m):
+    if isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight.data)
+
 class SparseLayer(nn.Module):
     def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
         super(SparseLayer, self).__init__()
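For reference, nn.init.xavier_uniform_ samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)) (Glorot & Bengio initialization), which keeps activation variance roughly stable across layers and so gives diverged weights a sane restart point. A quick check of that bound (the tensor shape is illustrative):

    import math
    import torch
    import torch.nn as nn

    w = torch.empty(8, 4)       # fan_in=4, fan_out=8 for an (out_features, in_features) weight
    nn.init.xavier_uniform_(w)  # draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)) at gain=1
    assert w.abs().max() <= math.sqrt(6.0 / (4 + 8))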
@@ -41,6 +46,7 @@ class SparseLayer(nn.Module):
             self.indices.append(indices)
             self.diag_shapes.append(diag_shape)
             self.weights.append(weights)
+        self.apply(xavier_init)
 
     def coo_sparse_layer(self, layer_id):
         with torch.no_grad():
@@ -91,6 +97,12 @@ class SparseLayer(nn.Module):
             particles.apply_weights(weights)
         return self._particles
 
+    def reset_diverged_particles(self):
+        for weights in self.weights:
+            if torch.isinf(weights).any() or torch.isnan(weights).any():
+                with torch.no_grad():
+                    xavier_init(weights)
+
     @property
     def particle_weights(self):
         all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
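One caveat worth noting: the elements of self.weights iterated here are raw parameter tensors, not nn.Linear modules, so the isinstance(m, nn.Linear) guard inside xavier_init makes xavier_init(weights) a silent no-op on them. A tensor-level variant that would actually re-draw a flat weight vector might look like the sketch below (the helper name xavier_init_tensor_ and viewing the 1-D tensor as a single 2-D fan are assumptions, not part of the commit):

    import torch
    import torch.nn as nn

    def xavier_init_tensor_(t: torch.Tensor):
        """Xavier-uniform re-draw for a flat weight tensor, viewed as one 2-D fan."""
        with torch.no_grad():
            nn.init.xavier_uniform_(t.view(1, -1))  # in-place on the view updates t itself

With such a helper, the loop body could call xavier_init_tensor_(weights) directly instead of routing the tensor through the module-level xavier_init.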
@@ -246,6 +258,10 @@ class SparseNetwork(nn.Module):
     def particle_weights(self):
         return (x for y in self.sparselayers for x in y.particle_weights)
 
+    def reset_diverged_particles(self):
+        for layer in self.sparselayers:
+            layer.reset_diverged_particles()
+
     def to(self, *args, **kwargs):
         super(SparseNetwork, self).to(*args, **kwargs)
         self.first_layer = self.first_layer.to(*args, **kwargs)
@@ -266,7 +282,7 @@ class SparseNetwork(nn.Module):
             output = layer(x)
             # loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
-            loss = loss_fn(output, target_data) * 100
+            loss = loss_fn(output, target_data) * 85
             losses.append(loss.detach())
             loss.backward()
@@ -314,6 +330,7 @@ def test_sparse_net_sef_train():
         tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}")
         for key, value in counter.items():
             df.loc[df.shape[0]] = (epoch, key, value)
+        net.reset_diverged_particles()
     counter = defaultdict(lambda: 0)
     id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
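In context, the self-train loop now re-draws any diverged particles once per epoch, right after the fixpoint census, so NaN/Inf particles cannot poison subsequent epochs. A toy, runnable version of that per-epoch placement (the stand-in particles and reset_diverged are illustrative, not the repository's classes):

    import torch
    import torch.nn as nn

    # Toy stand-ins: each "particle" is a tiny Linear; one is forced to diverge.
    particles = [nn.Linear(3, 3) for _ in range(4)]
    with torch.no_grad():
        particles[1].weight.fill_(float('nan'))

    def reset_diverged(particles):
        for p in particles:
            if any(bool(torch.isnan(t).any() or torch.isinf(t).any()) for t in p.parameters()):
                nn.init.xavier_uniform_(p.weight)  # fresh Xavier draw
                nn.init.zeros_(p.bias)

    for epoch in range(3):
        # ... a self-train step would run here ...
        reset_diverged(particles)                  # per-epoch reset, mirroring the hunk above

    assert not torch.isnan(particles[1].weight).any()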