MetaNetworks Debugged

This commit is contained in:
Steffen Illium
2022-01-31 10:35:11 +01:00
parent 49c0d8a621
commit 246d825bb4
8 changed files with 169 additions and 109 deletions

View File

@ -1,7 +1,6 @@
import copy
from typing import Dict, List
import numpy as np
from torch import Tensor
from network import Net
@ -9,7 +8,7 @@ def is_divergent(network: Net) -> bool:
for i in network.input_weight_matrix():
weight_value = i[0].item()
if np.isnan(weight_value).all() or np.isinf(weight_value).all():
if np.isnan(weight_value).any() or np.isinf(weight_value).any():
return True
return False
@ -25,7 +24,7 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
target_data = network.create_target_weights(network.input_weight_matrix().detach().numpy())
target_data = network.create_target_weights(network.input_weight_matrix().detach())
result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
# result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
return result
@ -95,4 +94,4 @@ def test_status(net: Net) -> Net:
else:
net.is_fixpoint = "other_func"
return net
return net