TeamWork 3>

This commit is contained in:
Si11ium
2019-03-05 12:51:41 +01:00
parent 18c84d1483
commit 7766fed5ab
4 changed files with 169 additions and 122 deletions

View File

@ -1,4 +1,3 @@
import sys
import os
import time
import dill
@ -75,15 +74,19 @@ class FixpointExperiment(Experiment):
self.counters['fix_sec'] += 1
else:
self.counters['other'] += 1
class MixedFixpointExperiment(FixpointExperiment):
def run_net(self, net, trains_per_application=100, step_limit=100):
i = 0
while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
net.self_attack()
for _ in tqdm(range(trains_per_application)):
loss = net.compiled().train()
with tqdm(postfix=["Loss", dict(value=0)]) as bar:
for _ in range(trains_per_application):
loss = net.compiled().train()
bar.postfix[1]["value"] = loss
bar.update()
i += 1
self.count(net)