save states for particles in soup (another decorator pattern)
This commit is contained in:
parent
2966b41baf
commit
c564c02b31
80
code/soup.py
80
code/soup.py
@ -17,36 +17,61 @@ class Soup:
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.generator = generator
|
self.generator = generator
|
||||||
self.particles = []
|
self.particles = []
|
||||||
self.params = dict(meeting_rate=0.1, train_other_rate=0.1, train=0)
|
self.historical_particles = {}
|
||||||
|
self.params = dict(attacking_rate=0.1, train_other_rate=0.1, train=0)
|
||||||
self.params.update(kwargs)
|
self.params.update(kwargs)
|
||||||
|
self.time = 0
|
||||||
|
|
||||||
def with_params(self, **kwargs):
|
def with_params(self, **kwargs):
|
||||||
self.params.update(kwargs)
|
self.params.update(kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def generate_particle(self):
|
||||||
|
new_particle = ParticleDecorator(self.generator())
|
||||||
|
self.historical_particles[new_particle.get_uid()] = new_particle
|
||||||
|
return new_particle
|
||||||
|
|
||||||
|
def get_particle(self, uid, otherwise=None):
|
||||||
|
return self.historical_particles.get(uid, otherwise)
|
||||||
|
|
||||||
def seed(self):
|
def seed(self):
|
||||||
self.particles = []
|
self.particles = []
|
||||||
for _ in range(self.size):
|
for _ in range(self.size):
|
||||||
self.particles += [self.generator()]
|
self.particles += [self.generate_particle()]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def evolve(self, iterations=1):
|
def evolve(self, iterations=1):
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
|
self.time += 1
|
||||||
for particle_id, particle in enumerate(self.particles):
|
for particle_id, particle in enumerate(self.particles):
|
||||||
if prng() < self.params.get('meeting_rate'):
|
description = {'time': self.time}
|
||||||
|
if prng() < self.params.get('attacking_rate'):
|
||||||
other_particle_id = int(prng() * len(self.particles))
|
other_particle_id = int(prng() * len(self.particles))
|
||||||
other_particle = self.particles[other_particle_id]
|
other_particle = self.particles[other_particle_id]
|
||||||
particle.attack(other_particle)
|
particle.attack(other_particle)
|
||||||
|
description['attacking'] = other_particle.get_uid()
|
||||||
if prng() < self.params.get('train_other_rate'):
|
if prng() < self.params.get('train_other_rate'):
|
||||||
other_particle_id = int(prng() * len(self.particles))
|
other_particle_id = int(prng() * len(self.particles))
|
||||||
other_particle = self.particles[other_particle_id]
|
other_particle = self.particles[other_particle_id]
|
||||||
particle.train_other(other_particle)
|
particle.train_other(other_particle)
|
||||||
|
description['training'] = other_particle.get_uid()
|
||||||
for _ in range(self.params.get('train', 0)):
|
for _ in range(self.params.get('train', 0)):
|
||||||
particle.compiled().train()
|
loss = particle.compiled().train()
|
||||||
|
description['fitted'] = self.params.get('train', 0)
|
||||||
|
description['loss'] = loss
|
||||||
if self.params.get('remove_divergent') and particle.is_diverged():
|
if self.params.get('remove_divergent') and particle.is_diverged():
|
||||||
self.particles[particle_id] = self.generator()
|
new_particle = self.generate_particle()
|
||||||
|
self.particles[particle_id] = new_particle
|
||||||
|
description['died'] = True
|
||||||
|
description['cause'] = 'divergent'
|
||||||
|
description['substitute'] = new_particle.get_uid()
|
||||||
if self.params.get('remove_zero') and particle.is_zero():
|
if self.params.get('remove_zero') and particle.is_zero():
|
||||||
self.particles[particle_id] = self.generator()
|
new_particle = self.generate_particle()
|
||||||
|
self.particles[particle_id] = new_particle
|
||||||
|
description['died'] = True
|
||||||
|
description['cause'] = 'zero'
|
||||||
|
description['substitute'] = new_particle.get_uid()
|
||||||
|
particle.save_state(**description)
|
||||||
|
|
||||||
def count(self):
|
def count(self):
|
||||||
counters = dict(divergent=0, fix_zero=0, fix_other=0, fix_sec=0, other=0)
|
counters = dict(divergent=0, fix_zero=0, fix_other=0, fix_sec=0, other=0)
|
||||||
@ -69,11 +94,42 @@ class Soup:
|
|||||||
particle.print_weights()
|
particle.print_weights()
|
||||||
print(particle.is_fixpoint())
|
print(particle.is_fixpoint())
|
||||||
|
|
||||||
|
class ParticleDecorator:
|
||||||
class LearningSoup(Soup):
|
|
||||||
|
next_uid = 0
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(LearningSoup, self).__init__(**kwargs)
|
def __init__(self, net):
|
||||||
|
self.uid = self.__class__.next_uid
|
||||||
|
self.__class__.next_uid += 1
|
||||||
|
self.net = net
|
||||||
|
self.states = []
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.net, name)
|
||||||
|
|
||||||
|
def get_uid(self):
|
||||||
|
return self.uid
|
||||||
|
|
||||||
|
def make_state(self, **kwargs):
|
||||||
|
state = {'class': self.net.__class__.__name__, 'weights': self.net.get_weights()}
|
||||||
|
state.update(kwargs)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def save_state(self, **kwargs):
|
||||||
|
state = self.make_state(**kwargs)
|
||||||
|
self.states += [state]
|
||||||
|
|
||||||
|
def update_state(self, number, **kwargs):
|
||||||
|
if number < len(self.states):
|
||||||
|
self.states[number] = self.make_state(**kwargs)
|
||||||
|
else:
|
||||||
|
for i in range(len(self.states), number):
|
||||||
|
self.states += [None]
|
||||||
|
self.states += self.make_state(**kwargs)
|
||||||
|
|
||||||
|
def get_states(self):
|
||||||
|
return self.states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -105,4 +161,6 @@ if __name__ == '__main__':
|
|||||||
soup.evolve()
|
soup.evolve()
|
||||||
soup.print_all()
|
soup.print_all()
|
||||||
exp.log(soup.count())
|
exp.log(soup.count())
|
||||||
|
exp.save(soup=soup) # you can access soup.historical_particles[particle_uid].states[time_step]['loss']
|
||||||
|
# or soup.historical_particles[particle_uid].states[time_step]['weights'] from soup.dill
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user