journal linspace basins
This commit is contained in:
@@ -84,21 +84,6 @@ def distance_from_parent(nets, distance="MIM", print_it=True):
|
||||
|
||||
class SpawnExperiment:
|
||||
|
||||
@staticmethod
|
||||
def apply_noise(network, noise: int):
|
||||
""" Changing the weights of a network to values + noise """
|
||||
|
||||
for layer_id, layer_name in enumerate(network.state_dict()):
|
||||
for line_id, line_values in enumerate(network.state_dict()[layer_name]):
|
||||
for weight_id, weight_value in enumerate(network.state_dict()[layer_name][line_id]):
|
||||
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
if prng() < 0.5:
|
||||
network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
else:
|
||||
network.state_dict()[layer_name][line_id][weight_id] = weight_value - noise
|
||||
|
||||
return network
|
||||
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, st_steps, nr_clones, noise, directory) -> None:
|
||||
self.population_size = population_size
|
||||
@@ -171,7 +156,7 @@ class SpawnExperiment:
|
||||
f"ST_net_{str(i)}_clone_{str(j)}", start_time=self.ST_steps)
|
||||
clone.load_state_dict(copy.deepcopy(net.state_dict()))
|
||||
rand_noise = prng() * self.noise
|
||||
clone = self.apply_noise(clone, rand_noise)
|
||||
clone = clone.apply_noise(rand_noise)
|
||||
clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history)
|
||||
clone.number_trained = copy.deepcopy(net.number_trained)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user