import copy
import os
import time

import dill
from tqdm import tqdm


class Experiment:
    """Context-managed experiment that keeps a message log and writes dill snapshots.

    Each ``with`` block creates a fresh directory
    ``experiments/exp-<name>-<id>-<iteration>``. On exit, a particle-free copy
    of the experiment and the accumulated log are written into that directory,
    and the iteration counter is advanced so the next ``with`` block gets a new
    directory.
    """

    @staticmethod
    def from_dill(path):
        """Load and return the object stored in the dill file at ``path``.

        NOTE(review): dill, like pickle, can execute arbitrary code while
        loading -- only open files produced by a trusted source.
        """
        with open(path, "rb") as dill_file:
            return dill.load(dill_file)

    def __init__(self, name=None, ident=None):
        """Create an experiment.

        Args:
            name: human-readable name; defaults to ``'unnamed_experiment'``.
            ident: optional identifier prefixed to the creation timestamp to
                form ``experiment_id``.
        """
        self.experiment_id = '{}_{}'.format(ident or '', time.time())
        self.experiment_name = name or 'unnamed_experiment'
        self.next_iteration = 0
        self.log_messages = []
        self.historical_particles = {}

    def __enter__(self):
        """Create this iteration's output directory and return ``self``.

        Raises:
            FileExistsError: if the directory already exists (``os.makedirs``
                is called without ``exist_ok``).
        """
        self.dir = os.path.join('experiments', 'exp-{name}-{id}-{it}'.format(
            name=self.experiment_name, id=self.experiment_id, it=self.next_iteration)
        )
        os.makedirs(self.dir)
        print("** created {dir} **".format(dir=self.dir))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Save a particle-free snapshot and the log, then advance the iteration.

        Returns ``None``, so any exception raised inside the ``with`` block
        still propagates after the files are written.
        """
        self.save(experiment=self.without_particles())
        self.save_log()
        self.next_iteration += 1

    def log(self, message, **kwargs):
        """Record ``message`` in the log and print it (``kwargs`` go to ``print``)."""
        self.log_messages.append(message)
        print(message, **kwargs)

    def save_log(self, log_name="log"):
        """Write every logged message, one per line, to ``<dir>/<log_name>.txt``."""
        with open(os.path.join(self.dir, "{name}.txt".format(name=log_name)), "w") as log_file:
            for log_message in self.log_messages:
                print(str(log_message), file=log_file)

    def __copy__(self):
        """Return a shallow copy without ``particles`` / ``historical_particles``.

        The copy has the same concrete class as ``self``, so subclass snapshots
        keep their type. ``__init__`` is deliberately bypassed: every attribute
        is copied from ``self``, and some subclasses (e.g.
        ``IdentLearningExperiment``) have constructors with different
        signatures.
        """
        cls = type(self)
        copy_ = cls.__new__(cls)
        copy_.__dict__.update(
            (attr, value) for attr, value in self.__dict__.items()
            if attr not in ('particles', 'historical_particles')
        )
        return copy_

    def without_particles(self):
        """Return a copy suitable for saving: particles replaced by their states.

        ``particles`` is dropped entirely (see ``__copy__``). Each entry of
        ``historical_particles`` is reduced to its ``.states`` attribute.
        """
        self_copy = copy.copy(self)
        self_copy.historical_particles = {key: val.states for key, val in self.historical_particles.items()}
        return self_copy

    def save(self, **kwargs):
        """Dill-dump each keyword value to ``<dir>/<keyword>.dill``."""
        for name, value in kwargs.items():
            with open(os.path.join(self.dir, "{name}.dill".format(name=name)), "wb") as dill_file:
                dill.dump(value, dill_file)


class FixpointExperiment(Experiment):
    """Repeatedly self-attack networks and tally how each run ends."""

    def __init__(self, **kwargs):
        """Create the experiment; ``name`` defaults to the class name."""
        kwargs.setdefault('name', self.__class__.__name__)
        super().__init__(**kwargs)
        self.counters = dict(divergent=0, fix_zero=0, fix_other=0, fix_sec=0, other=0)
        # Weights of non-zero fixpoints found by ``count``.
        self.interesting_fixpoints = []

    def run_net(self, net, step_limit=100, run_id=0):
        """Self-attack ``net`` until it diverges, reaches a fixpoint, or hits
        ``step_limit`` steps, then tally the outcome.

        If ``run_id`` is truthy, ``net.save_state(time=i)`` is called after
        every step.
        """
        i = 0
        while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
            net.self_attack()
            i += 1
            if run_id:
                net.save_state(time=i)
        self.count(net)

    def count(self, net):
        """Increment the counter matching the final state of ``net``.

        Checks are ordered: divergent, fixpoint (zero / non-zero), second-order
        fixpoint (``is_fixpoint(2)``), otherwise ``other``. Non-zero fixpoints
        also have their weights recorded in ``interesting_fixpoints``.
        """
        if net.is_diverged():
            self.counters['divergent'] += 1
        elif net.is_fixpoint():
            if net.is_zero():
                self.counters['fix_zero'] += 1
            else:
                self.counters['fix_other'] += 1
                self.interesting_fixpoints.append(net.get_weights())
        elif net.is_fixpoint(2):
            self.counters['fix_sec'] += 1
        else:
            self.counters['other'] += 1


class MixedFixpointExperiment(FixpointExperiment):
    """Fixpoint experiment that interleaves each self-attack with training."""

    def run_net(self, net, trains_per_application=100, step_limit=100, run_id=0):
        """Alternate one self-attack with ``trains_per_application`` training
        steps, up to ``step_limit`` times or until ``net`` diverges or reaches a
        fixpoint, then tally the outcome.
        """
        i = 0
        while i < step_limit and not net.is_diverged() and not net.is_fixpoint():
            net.self_attack()
            # ``total`` lets the bar display progress and an ETA.
            with tqdm(total=trains_per_application, postfix=["Loss", dict(value=0)]) as bar:
                for _ in range(trains_per_application):
                    loss = net.compiled().train()
                    bar.postfix[1]["value"] = loss
                    bar.update()
            i += 1
            if run_id:
                # NOTE(review): unlike FixpointExperiment.run_net this does not
                # pass ``time=i``; confirm that is intended.
                net.save_state()
        self.count(net)


class SoupExperiment(Experiment):
    """Experiment type for soup runs; currently adds nothing to ``Experiment``."""


class IdentLearningExperiment(Experiment):
    """Experiment named after its class; adds no behavior of its own yet."""

    def __init__(self):
        super().__init__(name=self.__class__.__name__)