This commit is contained in:
Si11ium
2019-03-07 12:05:58 +01:00
parent bae997feab
commit 95c2ff4200
4 changed files with 128 additions and 48 deletions

View File

@ -57,28 +57,30 @@ class Soup:
other_particle_id = int(prng() * len(self.particles))
other_particle = self.particles[other_particle_id]
particle.attack(other_particle)
description['attacking'] = other_particle.get_uid()
description['action'] = 'attacking'
description['counterpart'] = other_particle.get_uid()
if prng() < self.params.get('train_other_rate'):
other_particle_id = int(prng() * len(self.particles))
other_particle = self.particles[other_particle_id]
particle.train_other(other_particle)
description['training'] = other_particle.get_uid()
description['action'] = 'train_other'
description['counterpart'] = other_particle.get_uid()
for _ in range(self.params.get('train', 0)):
loss = particle.compiled().train()
description['fitted'] = self.params.get('train', 0)
description['loss'] = loss
description['action'] = 'train_self'
description['counterpart'] = None
if self.params.get('remove_divergent') and particle.is_diverged():
new_particle = self.generate_particle()
self.particles[particle_id] = new_particle
description['died'] = True
description['cause'] = 'divergent'
description['substitute'] = new_particle.get_uid()
description['action'] = 'divergent_dead'
description['counterpart'] = new_particle.get_uid()
if self.params.get('remove_zero') and particle.is_zero():
new_particle = self.generate_particle()
self.particles[particle_id] = new_particle
description['died'] = True
description['cause'] = 'zero'
description['substitute'] = new_particle.get_uid()
description['action'] = 'zweo_dead'
description['counterpart'] = new_particle.get_uid()
particle.save_state(**description)
def count(self):
@ -120,15 +122,22 @@ class ParticleDecorator:
return self.uid
def make_state(self, **kwargs):
state = {'class': self.net.__class__.__name__, 'weights': self.net.get_weights()}
weights = self.net.get_weights_flat()
if any(np.isinf(weights)):
return None
state = {'class': self.net.__class__.__name__, 'weights': weights}
state.update(kwargs)
return state
def save_state(self, **kwargs):
state = self.make_state(**kwargs)
self.states += [state]
if state is not None:
self.states += [state]
else:
pass
def update_state(self, number, **kwargs):
raise NotImplementedError('Result is vague')
if number < len(self.states):
self.states[number] = self.make_state(**kwargs)
else: