MetaNetworks Debugged
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import copy
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
from network import Net
|
||||
|
||||
|
||||
@ -9,7 +8,7 @@ def is_divergent(network: Net) -> bool:
|
||||
for i in network.input_weight_matrix():
|
||||
weight_value = i[0].item()
|
||||
|
||||
if np.isnan(weight_value).all() or np.isinf(weight_value).all():
|
||||
if np.isnan(weight_value).any() or np.isinf(weight_value).any():
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -25,7 +24,7 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
|
||||
|
||||
|
||||
def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
|
||||
target_data = network.create_target_weights(network.input_weight_matrix().detach().numpy())
|
||||
target_data = network.create_target_weights(network.input_weight_matrix().detach())
|
||||
result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
|
||||
# result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
|
||||
return result
|
||||
@ -95,4 +94,4 @@ def test_status(net: Net) -> Net:
|
||||
else:
|
||||
net.is_fixpoint = "other_func"
|
||||
|
||||
return net
|
||||
return net
|
||||
|
Reference in New Issue
Block a user