Visuals
This commit is contained in:
29
code/soup.py
29
code/soup.py
@ -57,28 +57,30 @@ class Soup:
|
||||
other_particle_id = int(prng() * len(self.particles))
|
||||
other_particle = self.particles[other_particle_id]
|
||||
particle.attack(other_particle)
|
||||
description['attacking'] = other_particle.get_uid()
|
||||
description['action'] = 'attacking'
|
||||
description['counterpart'] = other_particle.get_uid()
|
||||
if prng() < self.params.get('train_other_rate'):
|
||||
other_particle_id = int(prng() * len(self.particles))
|
||||
other_particle = self.particles[other_particle_id]
|
||||
particle.train_other(other_particle)
|
||||
description['training'] = other_particle.get_uid()
|
||||
description['action'] = 'train_other'
|
||||
description['counterpart'] = other_particle.get_uid()
|
||||
for _ in range(self.params.get('train', 0)):
|
||||
loss = particle.compiled().train()
|
||||
description['fitted'] = self.params.get('train', 0)
|
||||
description['loss'] = loss
|
||||
description['action'] = 'train_self'
|
||||
description['counterpart'] = None
|
||||
if self.params.get('remove_divergent') and particle.is_diverged():
|
||||
new_particle = self.generate_particle()
|
||||
self.particles[particle_id] = new_particle
|
||||
description['died'] = True
|
||||
description['cause'] = 'divergent'
|
||||
description['substitute'] = new_particle.get_uid()
|
||||
description['action'] = 'divergent_dead'
|
||||
description['counterpart'] = new_particle.get_uid()
|
||||
if self.params.get('remove_zero') and particle.is_zero():
|
||||
new_particle = self.generate_particle()
|
||||
self.particles[particle_id] = new_particle
|
||||
description['died'] = True
|
||||
description['cause'] = 'zero'
|
||||
description['substitute'] = new_particle.get_uid()
|
||||
description['action'] = 'zweo_dead'
|
||||
description['counterpart'] = new_particle.get_uid()
|
||||
particle.save_state(**description)
|
||||
|
||||
def count(self):
|
||||
@ -120,15 +122,22 @@ class ParticleDecorator:
|
||||
return self.uid
|
||||
|
||||
def make_state(self, **kwargs):
|
||||
state = {'class': self.net.__class__.__name__, 'weights': self.net.get_weights()}
|
||||
weights = self.net.get_weights_flat()
|
||||
if any(np.isinf(weights)):
|
||||
return None
|
||||
state = {'class': self.net.__class__.__name__, 'weights': weights}
|
||||
state.update(kwargs)
|
||||
return state
|
||||
|
||||
def save_state(self, **kwargs):
|
||||
state = self.make_state(**kwargs)
|
||||
self.states += [state]
|
||||
if state is not None:
|
||||
self.states += [state]
|
||||
else:
|
||||
pass
|
||||
|
||||
def update_state(self, number, **kwargs):
|
||||
raise NotImplementedError('Result is vague')
|
||||
if number < len(self.states):
|
||||
self.states[number] = self.make_state(**kwargs)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user