reset diverged by xavier init
@@ -393,6 +393,8 @@ if __name__ == '__main__':
         if use_sparse_network:
             dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
 
+        dense_metanet.reset_diverged_particles()
+
         # Task Train
         if not init_st:
             # Zero your gradients for every batch!
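Note on ordering: the freshly self-trained particle weights are copied from the sparse net back into the dense net first, and only then are diverged particles re-seeded, so the reset happens before the task-training phase starts rather than after it.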
@@ -487,6 +487,11 @@ class MetaNet(nn.Module):
     def all_layers(self):
         return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
 
+    def reset_diverged_particles(self):
+        for particle in self.particles:
+            if particle.is_fixpoint == FixTypes.divergent:
+                particle.apply(xavier_init)
+
 
 class MetaNetCompareBaseline(nn.Module):
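Here only particles whose fixpoint test came back `FixTypes.divergent` are touched; identity and other fixpoints keep their trained weights. `particle.apply(xavier_init)` relies on `nn.Module.apply` visiting every submodule recursively, so one call re-seeds the whole particle.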
@@ -17,6 +17,11 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 
 
+def xavier_init(m):
+    if isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight.data)
+
+
 class SparseLayer(nn.Module):
     def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
         super(SparseLayer, self).__init__()
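For reference, this helper is written to be handed to `nn.Module.apply`, which walks a module tree and calls the function on every submodule; anything that is not an `nn.Linear` falls through the `isinstance` guard, and biases are left at their existing values. A minimal, self-contained sketch of that mechanism (the toy network below is an assumption for illustration, not code from this repo):

    import torch
    import torch.nn as nn

    def xavier_init(m):
        # Re-seed only Linear weight matrices; everything else is skipped.
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight.data)

    # Hypothetical particle-sized network, purely for illustration.
    toy = nn.Sequential(nn.Linear(5, 2), nn.ReLU(), nn.Linear(2, 1))
    toy[0].weight.data.fill_(float('nan'))   # simulate a diverged particle
    toy.apply(xavier_init)                   # re-initialises both Linear layers
    assert torch.isfinite(toy[0].weight).all()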
@@ -41,6 +46,7 @@ class SparseLayer(nn.Module):
             self.indices.append(indices)
             self.diag_shapes.append(diag_shape)
             self.weights.append(weights)
+        self.apply(xavier_init)
 
     def coo_sparse_layer(self, layer_id):
         with torch.no_grad():
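`nn.Module.apply` visits children first and the module itself last, so placing the call at the end of `__init__` gives any `nn.Linear` submodule that exists at construction time a xavier-initialised weight matrix in a single pass.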
@@ -91,6 +97,12 @@ class SparseLayer(nn.Module):
             particles.apply_weights(weights)
         return self._particles
 
+    def reset_diverged_particles(self):
+        for weights in self.weights:
+            if torch.isinf(weights).any() or torch.isnan(weights).any():
+                with torch.no_grad():
+                    xavier_init(weights)
+
     @property
     def particle_weights(self):
         all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
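One caveat worth flagging: `xavier_init` only acts when handed an `nn.Linear`, so calling it on a raw weight tensor as done here is a silent no-op, and `nn.init.xavier_uniform_` rejects 1-D tensors anyway. A hedged sketch of a direct tensor reset, reusing the `(nr_nets, per_net)` view that `particle_weights` already assumes (a suggested variant, not the committed code):

    def reset_diverged_particles(self):
        for weights in self.weights:
            if torch.isinf(weights).any() or torch.isnan(weights).any():
                with torch.no_grad():
                    # Re-seed through a 2-D view (one row per net) so that
                    # xavier_uniform_ has a fan-in/fan-out to compute; the
                    # fans are an approximation, since the true layer shape
                    # is flattened away in this representation.
                    nn.init.xavier_uniform_(weights.view(self.nr_nets, -1))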
@@ -246,6 +258,10 @@ class SparseNetwork(nn.Module):
     def particle_weights(self):
         return (x for y in self.sparselayers for x in y.particle_weights)
 
+    def reset_diverged_particles(self):
+        for layer in self.sparselayers:
+            layer.reset_diverged_particles()
+
     def to(self, *args, **kwargs):
         super(SparseNetwork, self).to(*args, **kwargs)
         self.first_layer = self.first_layer.to(*args, **kwargs)
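The network-level reset just fans out to each `SparseLayer`, so divergence is detected per layer-wide flat weight tensor here, not per individual particle as in the dense `MetaNet` version above.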
@@ -266,7 +282,7 @@ class SparseNetwork(nn.Module):
             output = layer(x)
             # loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
 
-            loss = loss_fn(output, target_data) * 100
+            loss = loss_fn(output, target_data) * 85
 
             losses.append(loss.detach())
             loss.backward()
@@ -314,6 +330,7 @@ def test_sparse_net_sef_train():
         tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}")
         for key, value in counter.items():
             df.loc[df.shape[0]] = (epoch, key, value)
+        net.reset_diverged_particles()
 
         counter = defaultdict(lambda: 0)
         id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
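Worth noting: the reset runs after the epoch's fixpoint counts are written to `df` and before the counter is rebuilt, so the recorded statistics reflect the pre-reset population while the next epoch starts from re-seeded particles.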