Things
This commit is contained in:
46
code/soup.py
46
code/soup.py
@ -7,7 +7,7 @@ def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
class Soup:
|
||||
class Soup(object):
|
||||
|
||||
def __init__(self, size, generator, **kwargs):
|
||||
self.size = size
|
||||
@ -105,50 +105,6 @@ class Soup:
|
||||
print(particle.is_fixpoint())
|
||||
|
||||
|
||||
class ParticleDecorator:
|
||||
|
||||
next_uid = 0
|
||||
|
||||
def __init__(self, net):
|
||||
self.uid = self.__class__.next_uid
|
||||
self.__class__.next_uid += 1
|
||||
self.net = net
|
||||
self.states = []
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.net, name)
|
||||
|
||||
def get_uid(self):
|
||||
return self.uid
|
||||
|
||||
def make_state(self, **kwargs):
|
||||
weights = self.net.get_weights_flat()
|
||||
if any(np.isinf(weights)):
|
||||
return None
|
||||
state = {'class': self.net.__class__.__name__, 'weights': weights}
|
||||
state.update(kwargs)
|
||||
return state
|
||||
|
||||
def save_state(self, **kwargs):
|
||||
state = self.make_state(**kwargs)
|
||||
if state is not None:
|
||||
self.states += [state]
|
||||
else:
|
||||
pass
|
||||
|
||||
def update_state(self, number, **kwargs):
|
||||
raise NotImplementedError('Result is vague')
|
||||
if number < len(self.states):
|
||||
self.states[number] = self.make_state(**kwargs)
|
||||
else:
|
||||
for i in range(len(self.states), number):
|
||||
self.states += [None]
|
||||
self.states += self.make_state(**kwargs)
|
||||
|
||||
def get_states(self):
|
||||
return self.states
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if True:
|
||||
with SoupExperiment() as exp:
|
||||
|
Reference in New Issue
Block a user