TeamWork 3>
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import dill
|
||||
@ -75,15 +74,19 @@ class FixpointExperiment(Experiment):
|
||||
self.counters['fix_sec'] += 1
|
||||
else:
|
||||
self.counters['other'] += 1
|
||||
|
||||
|
||||
|
||||
class MixedFixpointExperiment(FixpointExperiment):
|
||||
|
||||
def run_net(self, net, trains_per_application=100, step_limit=100):
|
||||
i = 0
|
||||
while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
|
||||
net.self_attack()
|
||||
for _ in tqdm(range(trains_per_application)):
|
||||
loss = net.compiled().train()
|
||||
with tqdm(postfix=["Loss", dict(value=0)]) as bar:
|
||||
for _ in range(trains_per_application):
|
||||
loss = net.compiled().train()
|
||||
bar.postfix[1]["value"] = loss
|
||||
bar.update()
|
||||
i += 1
|
||||
self.count(net)
|
||||
|
||||
|
Reference in New Issue
Block a user