MetaNetworks Debugged II
@@ -1,16 +1,13 @@
 import copy
 from typing import Dict, List
 import numpy as np
 import torch
 from tqdm import tqdm

 from network import Net


 def is_divergent(network: Net) -> bool:
-    for i in network.input_weight_matrix():
-        weight_value = i[0].item()
-
-        if np.isnan(weight_value).any() or np.isinf(weight_value).any():
-            return True
-    return False
+    return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()


 def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
@@ -19,13 +16,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(input_data)
     predicted_values = network(input_data)

-    return np.allclose(target_data.detach().numpy(), predicted_values.detach().numpy(),
-                       rtol=0, atol=epsilon)
+    return torch.allclose(target_data.detach(), predicted_values.detach(),
+                          rtol=0, atol=epsilon)


 def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
-    result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
+    result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result

@@ -49,15 +47,15 @@ def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
     second_output = network(input_data_2)

     # Perform the Check: all(epsilon > abs(input_data - second_output))
-    check_abs_within_epsilon = np.allclose(target_data.detach().numpy(), second_output.detach().numpy(),
-                                           rtol=0, atol=epsilon)
+    check_abs_within_epsilon = torch.allclose(target_data.detach(), second_output.detach(),
+                                              rtol=0, atol=epsilon)
     return check_abs_within_epsilon


 def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
     id_functions = id_functions or list()

-    for net in nets:
+    for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
         if is_divergent(net):
             fixpoint_counter["divergent"] += 1
             net.is_fixpoint = "divergent"
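For context, a minimal standalone sketch (not part of the commit) of the torch-native checks this change switches to. The toy tensor below is an assumed stand-in for the output of Net.input_weight_matrix(), which is not shown here.

import torch

# Toy stand-in for network.input_weight_matrix(); the real shape comes from Net.
weights = torch.tensor([[0.5], [float('nan')], [1.0]])

# New is_divergent logic: any non-finite weight marks the network as divergent.
divergent = weights.isinf().any().item() or weights.isnan().any().item()
print(divergent)  # True, because of the NaN entry

# torch.allclose with rtol=0 reproduces the former np.allclose(..., rtol=0, atol=epsilon)
# comparison directly on tensors, without the .detach().numpy() round-trip.
target = torch.zeros(3, 1)
predicted = torch.full((3, 1), 1e-6)
print(torch.allclose(target, predicted, rtol=0, atol=pow(10, -5)))  # True: within epsilon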