This commit is contained in:
Si11ium
2019-03-07 12:05:58 +01:00
parent bae997feab
commit 95c2ff4200
4 changed files with 128 additions and 48 deletions

View File

@ -2,8 +2,7 @@ import os
import time
import dill
from tqdm import tqdm
from collections import defaultdict
import copy
class Experiment:
@ -19,7 +18,7 @@ class Experiment:
self.base_dir = self.experiment_name
self.next_iteration = 0
self.log_messages = []
self.data_storage = defaultdict(list)
self.historical_particles = dict()
def __enter__(self):
self.dir = os.path.join(self.base_dir, 'experiments', 'exp-{name}-{id}-{it}'.format(
@ -31,7 +30,7 @@ class Experiment:
return self
def __exit__(self, exc_type, exc_value, traceback):
self.save(experiment=self)
self.save(experiment=self.without_particles())
self.save_log()
self.next_iteration += 1
@ -43,14 +42,26 @@ class Experiment:
with open(os.path.join(self.dir, "{name}.txt".format(name=log_name)), "w") as log_file:
for log_message in self.log_messages:
print(str(log_message), file=log_file)
def __copy__(self):
copy_ = Experiment(name=self.experiment_name,)
copy_.__dict__ = {attr: self.__dict__[attr] for attr in self.__dict__ if
attr not in ['particles', 'historical_particles']}
return copy_
def without_particles(self):
self_copy = copy.copy(self)
# self_copy.particles = [particle.states for particle in self.particles]
self_copy.historical_particles = {key: val.states for key, val in self.historical_particles.items()}
return self_copy
def save(self, **kwargs):
for name, value in kwargs.items():
with open(os.path.join(self.dir, "{name}.dill".format(name=name)), "wb") as dill_file:
dill.dump(value, dill_file)
def add_trajectory_segment(self, run_id, trajectory):
self.data_storage[run_id].append(trajectory)
self.historical_particles[run_id].append(trajectory)
return