diff --git a/code/methods.py b/code/methods.py index 2aa0c2f..2490fe8 100644 --- a/code/methods.py +++ b/code/methods.py @@ -3,6 +3,10 @@ from keras.models import Sequential, Model from keras.layers import SimpleRNN, Dense from keras.layers import Input, TimeDistributed from tqdm import tqdm +import time +import os +import dill + import itertools @@ -168,16 +172,47 @@ class FeedForwardNetwork(_BaseNetwork): bar.update() return losses +class Experiment: + + @staticmethod + def from_dill(path): + with open(path) as dill_file: + return dill.load(dill_file) + + def __init__(self, name=None, id=None): + self.experiment_id = id or time.time() + this_file = os.path.realpath(__file__) + self.experiment_name = name or os.path.basename(this_file) + self.base_dir = os.path.realpath((os.path.dirname(this_file) + "/..")) + "/" + self.next_iteration = 0 + + def __enter__(self): + self.dir = self.base_dir + "experiments/exp-" + str(self.experiment_name) + "-" + str(self.experiment_id) + "-" + str(self.next_iteration) + "/" + os.mkdir(self.dir) + print("** created " + str(self.dir)) + + def __exit__(self, exc_type, exc_value, traceback): + self.save(experiment=self) + self.next_iteration += 1 + + def save(self, **kwargs): + for name,value in kwargs.items(): + with open(self.dir + "/" + str(name) + ".dill", "wb") as dill_file: + dill.dump(value, dill_file) + + + if __name__ == '__main__': - features, cells, layers = 2, 2, 2 - use_recurrent = False - if use_recurrent: - network = Network(features, cells, layers, recurrent=use_recurrent) - r = RecurrentNetwork(network) - loss = r.fit(epochs=10) - else: - network = Network(features, cells, layers, recurrent=use_recurrent) - ff = FeedForwardNetwork(network) - loss = ff.fit(epochs=10) - print(loss) + with Experiment() as exp: + features, cells, layers = 2, 2, 2 + use_recurrent = False + if use_recurrent: + network = Network(features, cells, layers, recurrent=use_recurrent) + r = RecurrentNetwork(network) + loss = r.fit(epochs=10) + else: + network = Network(features, cells, layers, recurrent=use_recurrent) + ff = FeedForwardNetwork(network) + loss = ff.fit(epochs=10) + print(loss)