built class for training all networks, including working fixpoint check and some experiments on that

This commit is contained in:
Thomas Gabor
2019-03-05 04:42:50 +01:00
parent 7ea8287b0e
commit 19e4ed65f9
2 changed files with 271 additions and 31 deletions

View File

@ -2,6 +2,7 @@ import sys
import os
import time
import dill
from tqdm import tqdm
class Experiment:
@ -69,15 +70,23 @@ class FixpointExperiment(Experiment):
self.counters['fix_zero'] += 1
else:
self.counters['fix_other'] += 1
self.interesting_fixpoints.append(net)
self.log(net.repr_weights())
net.self_attack()
self.log(net.repr_weights())
self.interesting_fixpoints.append(net.get_weights())
elif net.is_fixpoint(2):
self.counters['fix_sec'] += 1
else:
self.counters['other'] += 1
class MixedFixpointExperiment(FixpointExperiment):
def run_net(self, net, trains_per_application=100, step_limit=100):
i = 0
while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
net.self_attack()
for _ in tqdm(range(trains_per_application)):
loss = net.compiled().train()
i += 1
self.count(net)
class SoupExperiment(Experiment):
pass