plotting
This commit is contained in:
27
code/soup.py
27
code/soup.py
@ -1,9 +1,5 @@
|
||||
import random
|
||||
import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiment import *
|
||||
from network import *
|
||||
|
||||
|
||||
@ -21,7 +17,19 @@ class Soup:
|
||||
self.params = dict(attacking_rate=0.1, train_other_rate=0.1, train=0)
|
||||
self.params.update(kwargs)
|
||||
self.time = 0
|
||||
|
||||
|
||||
def __copy__(self):
|
||||
copy_ = Soup(self.size, self.generator, **self.params)
|
||||
copy_.__dict__ = {attr: self.__dict__[attr] for attr in self.__dict__ if
|
||||
attr not in ['particles', 'historical_particles']}
|
||||
return copy_
|
||||
|
||||
def without_particles(self):
|
||||
self_copy = copy.copy(self)
|
||||
# self_copy.particles = [particle.states for particle in self.particles]
|
||||
self_copy.historical_particles = {key: val.states for key, val in self.historical_particles.items()}
|
||||
return self_copy
|
||||
|
||||
def with_params(self, **kwargs):
|
||||
self.params.update(kwargs)
|
||||
return self
|
||||
@ -94,6 +102,7 @@ class Soup:
|
||||
particle.print_weights()
|
||||
print(particle.is_fixpoint())
|
||||
|
||||
|
||||
class ParticleDecorator:
|
||||
|
||||
next_uid = 0
|
||||
@ -131,7 +140,6 @@ class ParticleDecorator:
|
||||
return self.states
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if False:
|
||||
with SoupExperiment() as exp:
|
||||
@ -155,12 +163,11 @@ if __name__ == '__main__':
|
||||
# net_generator = lambda: AggregatingNeuralNetwork(4, 2, 2).with_keras_params(activation='sigmoid')\
|
||||
# .with_params(shuffler=AggregatingNeuralNetwork.shuffle_random)
|
||||
# net_generator = lambda: RecurrentNeuralNetwork(2, 2).with_keras_params(activation='linear').with_params()
|
||||
soup = Soup(10, net_generator).with_params(remove_divergent=True, remove_zero=True, train=200)
|
||||
soup = Soup(10, net_generator).with_params(remove_divergent=True, remove_zero=True, train=10)
|
||||
soup.seed()
|
||||
for _ in tqdm(range(10)):
|
||||
for _ in tqdm(range(100)):
|
||||
soup.evolve()
|
||||
soup.print_all()
|
||||
exp.log(soup.count())
|
||||
exp.save(soup=soup) # you can access soup.historical_particles[particle_uid].states[time_step]['loss']
|
||||
exp.save(soup=soup.without_particles()) # you can access soup.historical_particles[particle_uid].states[time_step]['loss']
|
||||
# or soup.historical_particles[particle_uid].states[time_step]['weights'] from soup.dill
|
||||
|
||||
|
Reference in New Issue
Block a user