diff --git a/code/soup.py b/code/soup.py index d813378..8f4e435 100644 --- a/code/soup.py +++ b/code/soup.py @@ -17,36 +17,61 @@ class Soup: self.size = size self.generator = generator self.particles = [] - self.params = dict(meeting_rate=0.1, train_other_rate=0.1, train=0) + self.historical_particles = {} + self.params = dict(attacking_rate=0.1, train_other_rate=0.1, train=0) self.params.update(kwargs) + self.time = 0 def with_params(self, **kwargs): self.params.update(kwargs) return self + def generate_particle(self): + new_particle = ParticleDecorator(self.generator()) + self.historical_particles[new_particle.get_uid()] = new_particle + return new_particle + + def get_particle(self, uid, otherwise=None): + return self.historical_particles.get(uid, otherwise) + def seed(self): self.particles = [] for _ in range(self.size): - self.particles += [self.generator()] + self.particles += [self.generate_particle()] return self def evolve(self, iterations=1): for _ in range(iterations): + self.time += 1 for particle_id, particle in enumerate(self.particles): - if prng() < self.params.get('meeting_rate'): + description = {'time': self.time} + if prng() < self.params.get('attacking_rate'): other_particle_id = int(prng() * len(self.particles)) other_particle = self.particles[other_particle_id] particle.attack(other_particle) + description['attacking'] = other_particle.get_uid() if prng() < self.params.get('train_other_rate'): other_particle_id = int(prng() * len(self.particles)) other_particle = self.particles[other_particle_id] particle.train_other(other_particle) + description['training'] = other_particle.get_uid() for _ in range(self.params.get('train', 0)): - particle.compiled().train() + loss = particle.compiled().train() + description['fitted'] = self.params.get('train', 0) + description['loss'] = loss if self.params.get('remove_divergent') and particle.is_diverged(): - self.particles[particle_id] = self.generator() + new_particle = self.generate_particle() + self.particles[particle_id] = new_particle + description['died'] = True + description['cause'] = 'divergent' + description['substitute'] = new_particle.get_uid() if self.params.get('remove_zero') and particle.is_zero(): - self.particles[particle_id] = self.generator() + new_particle = self.generate_particle() + self.particles[particle_id] = new_particle + description['died'] = True + description['cause'] = 'zero' + description['substitute'] = new_particle.get_uid() + particle.save_state(**description) def count(self): counters = dict(divergent=0, fix_zero=0, fix_other=0, fix_sec=0, other=0) @@ -69,11 +94,42 @@ class Soup: particle.print_weights() print(particle.is_fixpoint()) - -class LearningSoup(Soup): - - def __init__(self, *args, **kwargs): - super(LearningSoup, self).__init__(**kwargs) +class ParticleDecorator: + + next_uid = 0 + + def __init__(self, net): + self.uid = self.__class__.next_uid + self.__class__.next_uid += 1 + self.net = net + self.states = [] + + def __getattr__(self, name): + return getattr(self.net, name) + + def get_uid(self): + return self.uid + + def make_state(self, **kwargs): + state = {'class': self.net.__class__.__name__, 'weights': self.net.get_weights()} + state.update(kwargs) + return state + + def save_state(self, **kwargs): + state = self.make_state(**kwargs) + self.states += [state] + + def update_state(self, number, **kwargs): + if number < len(self.states): + self.states[number] = self.make_state(**kwargs) + else: + for i in range(len(self.states), number): + self.states += [None] + self.states += self.make_state(**kwargs) + + def get_states(self): + return self.states + if __name__ == '__main__': @@ -105,4 +161,6 @@ if __name__ == '__main__': soup.evolve() soup.print_all() exp.log(soup.count()) + exp.save(soup=soup) # you can access soup.historical_particles[particle_uid].states[time_step]['loss'] + # or soup.historical_particles[particle_uid].states[time_step]['weights'] from soup.dill