wrote quick experimentation class

This commit is contained in:
Thomas Gabor
2019-03-02 18:24:24 +01:00
parent ee3ac7d41a
commit 9feb1bd3d9

View File

@@ -3,6 +3,10 @@ from keras.models import Sequential, Model
from keras.layers import SimpleRNN, Dense from keras.layers import SimpleRNN, Dense
from keras.layers import Input, TimeDistributed from keras.layers import Input, TimeDistributed
from tqdm import tqdm from tqdm import tqdm
import time
import os
import dill
import itertools import itertools
@@ -168,16 +172,47 @@ class FeedForwardNetwork(_BaseNetwork):
bar.update() bar.update()
return losses return losses
class Experiment:
@staticmethod
def from_dill(path):
with open(path) as dill_file:
return dill.load(dill_file)
def __init__(self, name=None, id=None):
self.experiment_id = id or time.time()
this_file = os.path.realpath(__file__)
self.experiment_name = name or os.path.basename(this_file)
self.base_dir = os.path.realpath((os.path.dirname(this_file) + "/..")) + "/"
self.next_iteration = 0
def __enter__(self):
self.dir = self.base_dir + "experiments/exp-" + str(self.experiment_name) + "-" + str(self.experiment_id) + "-" + str(self.next_iteration) + "/"
os.mkdir(self.dir)
print("** created " + str(self.dir))
def __exit__(self, exc_type, exc_value, traceback):
self.save(experiment=self)
self.next_iteration += 1
def save(self, **kwargs):
for name,value in kwargs.items():
with open(self.dir + "/" + str(name) + ".dill", "wb") as dill_file:
dill.dump(value, dill_file)
if __name__ == '__main__': if __name__ == '__main__':
features, cells, layers = 2, 2, 2 with Experiment() as exp:
use_recurrent = False features, cells, layers = 2, 2, 2
if use_recurrent: use_recurrent = False
network = Network(features, cells, layers, recurrent=use_recurrent) if use_recurrent:
r = RecurrentNetwork(network) network = Network(features, cells, layers, recurrent=use_recurrent)
loss = r.fit(epochs=10) r = RecurrentNetwork(network)
else: loss = r.fit(epochs=10)
network = Network(features, cells, layers, recurrent=use_recurrent) else:
ff = FeedForwardNetwork(network) network = Network(features, cells, layers, recurrent=use_recurrent)
loss = ff.fit(epochs=10) ff = FeedForwardNetwork(network)
print(loss) loss = ff.fit(epochs=10)
print(loss)