This commit is contained in:
Si11ium
2019-03-06 19:21:19 +01:00
parent 6ced18c2d7
commit bae997feab
3 changed files with 69 additions and 43 deletions

View File

@ -1,9 +1,5 @@
import random
import copy
from tqdm import tqdm
from experiment import *
from network import *
@ -21,7 +17,19 @@ class Soup:
self.params = dict(attacking_rate=0.1, train_other_rate=0.1, train=0)
self.params.update(kwargs)
self.time = 0
def __copy__(self):
copy_ = Soup(self.size, self.generator, **self.params)
copy_.__dict__ = {attr: self.__dict__[attr] for attr in self.__dict__ if
attr not in ['particles', 'historical_particles']}
return copy_
def without_particles(self):
self_copy = copy.copy(self)
# self_copy.particles = [particle.states for particle in self.particles]
self_copy.historical_particles = {key: val.states for key, val in self.historical_particles.items()}
return self_copy
def with_params(self, **kwargs):
self.params.update(kwargs)
return self
@ -94,6 +102,7 @@ class Soup:
particle.print_weights()
print(particle.is_fixpoint())
class ParticleDecorator:
next_uid = 0
@ -131,7 +140,6 @@ class ParticleDecorator:
return self.states
if __name__ == '__main__':
if False:
with SoupExperiment() as exp:
@ -155,12 +163,11 @@ if __name__ == '__main__':
# net_generator = lambda: AggregatingNeuralNetwork(4, 2, 2).with_keras_params(activation='sigmoid')\
# .with_params(shuffler=AggregatingNeuralNetwork.shuffle_random)
# net_generator = lambda: RecurrentNeuralNetwork(2, 2).with_keras_params(activation='linear').with_params()
soup = Soup(10, net_generator).with_params(remove_divergent=True, remove_zero=True, train=200)
soup = Soup(10, net_generator).with_params(remove_divergent=True, remove_zero=True, train=10)
soup.seed()
for _ in tqdm(range(10)):
for _ in tqdm(range(100)):
soup.evolve()
soup.print_all()
exp.log(soup.count())
exp.save(soup=soup) # you can access soup.historical_particles[particle_uid].states[time_step]['loss']
exp.save(soup=soup.without_particles()) # you can access soup.historical_particles[particle_uid].states[time_step]['loss']
# or soup.historical_particles[particle_uid].states[time_step]['weights'] from soup.dill