journal linspace basins

This commit is contained in:
steffen-illium
2021-06-25 10:25:25 +02:00
parent cf6eec639f
commit 14d9a533cb
8 changed files with 69 additions and 100 deletions

View File

@@ -1,7 +1,6 @@
import copy
import random
import os.path
import pickle
from pathlib import Path
from typing import Union
@@ -13,7 +12,6 @@ from matplotlib import pyplot as plt
from torch.nn import functional as F
from tabulate import tabulate
from experiments.helpers import check_folder, summary_fixpoint_percentage, summary_fixpoint_experiment
from functionalities_test import test_for_fixpoints, is_zero_fixpoint, is_divergent, is_identity_function
from network import Net
from visualization import plot_loss, bar_chart_fixpoints, plot_3d_soup, line_chart_fixpoints
@@ -25,20 +23,6 @@ def prng():
class SoupRobustnessExperiment:
@staticmethod
def apply_noise(network, noise: int):
""" Changing the weights of a network to values + noise """
for layer_id, layer_name in enumerate(network.state_dict()):
for line_id, line_values in enumerate(network.state_dict()[layer_name]):
for weight_id, weight_value in enumerate(network.state_dict()[layer_name][line_id]):
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
if prng() < 0.5:
network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
else:
network.state_dict()[layer_name][line_id][weight_id] = weight_value - noise
return network
def __init__(self, population_size, net_i_size, net_h_size, net_o_size, learning_rate, attack_chance,
train_nets, ST_steps, epochs, log_step_size, directory: Union[str, Path]):
super().__init__()
@@ -146,8 +130,7 @@ class SoupRobustnessExperiment:
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
f"{fixpoint.name}_clone_noise10e-{noise_level}")
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
rand_noise = prng() * pow(10, -noise_level) # n / 1
clone = self.apply_noise(clone, rand_noise)
clone = clone.apply_noise(pow(10, -noise_level))
while not is_zero_fixpoint(clone) and not is_divergent(clone):
if is_identity_function(clone):