This commit is contained in:
Si11ium
2019-03-11 16:12:10 +01:00
parent 7a76b1ba88
commit 1f76c06c01
4 changed files with 159 additions and 147 deletions

View File

@ -13,9 +13,8 @@ class Experiment:
return dill.load(dill_file)
def __init__(self, name=None, ident=None):
self.experiment_id = ident or time.time()
self.experiment_id = '{}_{}'.format(ident or '', time.time().as_integer_ratio()[0])
self.experiment_name = name or 'unnamed_experiment'
self.base_dir = self.experiment_name
self.next_iteration = 0
self.log_messages = []
self.historical_particles = {}
@ -62,8 +61,8 @@ class Experiment:
class FixpointExperiment(Experiment):
def __init__(self):
super().__init__(name=self.__class__.__name__)
def __init__(self, **kwargs):
super().__init__(name=self.__class__.__name__, **kwargs)
self.counters = dict(divergent=0, fix_zero=0, fix_other=0, fix_sec=0, other=0)
self.interesting_fixpoints = []
@ -73,7 +72,7 @@ class FixpointExperiment(Experiment):
net.self_attack()
i += 1
if run_id:
net.save_state(time=run_id)
net.save_state(time=i)
self.count(net)
def count(self, net):
@ -94,9 +93,6 @@ class FixpointExperiment(Experiment):
class MixedFixpointExperiment(FixpointExperiment):
def run_net(self, net, trains_per_application=100, step_limit=100, run_id=0):
# TODO Where to place the trajectory storage ?
# weights = net.get_weights()
# self.add_trajectory_segment(run_id, weights)
i = 0
while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
@ -107,6 +103,8 @@ class MixedFixpointExperiment(FixpointExperiment):
bar.postfix[1]["value"] = loss
bar.update()
i += 1
if run_id:
net.save_state()
self.count(net)
@ -115,4 +113,7 @@ class SoupExperiment(Experiment):
class IdentLearningExperiment(Experiment):
def __init__(self):
super(IdentLearningExperiment, self).__init__(name=self.__class__.__name__)
pass