built class for training all networks, including working fixpoint check and some experiments on that
This commit is contained in:
@ -2,6 +2,7 @@ import sys
|
||||
import os
|
||||
import time
|
||||
import dill
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class Experiment:
|
||||
@ -69,15 +70,23 @@ class FixpointExperiment(Experiment):
|
||||
self.counters['fix_zero'] += 1
|
||||
else:
|
||||
self.counters['fix_other'] += 1
|
||||
self.interesting_fixpoints.append(net)
|
||||
self.log(net.repr_weights())
|
||||
net.self_attack()
|
||||
self.log(net.repr_weights())
|
||||
self.interesting_fixpoints.append(net.get_weights())
|
||||
elif net.is_fixpoint(2):
|
||||
self.counters['fix_sec'] += 1
|
||||
else:
|
||||
self.counters['other'] += 1
|
||||
|
||||
class MixedFixpointExperiment(FixpointExperiment):
|
||||
|
||||
def run_net(self, net, trains_per_application=100, step_limit=100):
|
||||
i = 0
|
||||
while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
|
||||
net.self_attack()
|
||||
for _ in tqdm(range(trains_per_application)):
|
||||
loss = net.compiled().train()
|
||||
i += 1
|
||||
self.count(net)
|
||||
|
||||
|
||||
class SoupExperiment(Experiment):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user