Journal TEx Text
This commit is contained in:
@@ -63,6 +63,23 @@ class SoupExperiment:
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
self.population.append(net)
|
||||
|
||||
def population_self_train(self):
|
||||
# Self-training each network in the population
|
||||
for j in range(self.population_size):
|
||||
net = self.population[j]
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
def population_attack(self):
|
||||
# A network attacking another network with a given percentage
|
||||
if random.randint(1, 100) <= self.attack_chance:
|
||||
random_net1, random_net2 = random.sample(range(self.population_size), 2)
|
||||
random_net1 = self.population[random_net1]
|
||||
random_net2 = self.population[random_net2]
|
||||
print(f"\n Attack: {random_net1.name} -> {random_net2.name}")
|
||||
random_net1.attack(random_net2)
|
||||
|
||||
def evolve(self):
|
||||
""" Evolving consists of attacking & self-training. """
|
||||
|
||||
@@ -71,19 +88,10 @@ class SoupExperiment:
|
||||
loop_epochs.set_description("Evolving soup %s" % i)
|
||||
|
||||
# A network attacking another network with a given percentage
|
||||
if random.randint(1, 100) <= self.attack_chance:
|
||||
random_net1, random_net2 = random.sample(range(self.population_size), 2)
|
||||
random_net1 = self.population[random_net1]
|
||||
random_net2 = self.population[random_net2]
|
||||
print(f"\n Attack: {random_net1.name} -> {random_net2.name}")
|
||||
random_net1.attack(random_net2)
|
||||
self.population_attack()
|
||||
|
||||
# Self-training each network in the population
|
||||
for j in range(self.population_size):
|
||||
net = self.population[j]
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
self.population_self_train()
|
||||
|
||||
# Testing for fixpoints after each batch of ST steps to see relevant data
|
||||
if i % self.ST_steps == 0:
|
||||
|
50
experiments/soup_melt_exp.py
Normal file
50
experiments/soup_melt_exp.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.soup_exp import SoupExperiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
|
||||
class MeltingSoupExperiment(SoupExperiment):
|
||||
|
||||
def __init__(self, melt_chance, *args, keep_population_size=True, **kwargs):
|
||||
super(MeltingSoupExperiment, self).__init__(*args, **kwargs)
|
||||
self.keep_population_size = keep_population_size
|
||||
self.melt_chance = melt_chance
|
||||
|
||||
def population_melt(self):
|
||||
# A network melting with another network by a given percentage
|
||||
if random.randint(1, 100) <= self.melt_chance:
|
||||
random_net1_idx, random_net2_idx, destroy_idx = random.sample(range(self.population_size), 3)
|
||||
random_net1 = self.population[random_net1_idx]
|
||||
random_net2 = self.population[random_net2_idx]
|
||||
print(f"\n Melt: {random_net1.name} -> {random_net2.name}")
|
||||
melted_network = random_net1.melt(random_net2)
|
||||
if self.keep_population_size:
|
||||
del self.population[destroy_idx]
|
||||
self.population.append(melted_network)
|
||||
|
||||
def evolve(self):
|
||||
""" Evolving consists of attacking, melting & self-training. """
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs))
|
||||
for i in loop_epochs:
|
||||
loop_epochs.set_description("Evolving soup %s" % i)
|
||||
|
||||
self.population_attack()
|
||||
|
||||
self.population_melt()
|
||||
|
||||
self.population_self_train()
|
||||
|
||||
# Testing for fixpoints after each batch of ST steps to see relevant data
|
||||
if i % self.ST_steps == 0:
|
||||
test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
fixpoints_percentage = round(self.fixpoint_counters["identity_func"] / self.population_size, 1)
|
||||
self.fixpoint_counters_history.append(fixpoints_percentage)
|
||||
|
||||
# Resetting the fixpoint counter. Last iteration not to be reset -
|
||||
# it is important for the bar_chart_fixpoints().
|
||||
if i < self.epochs:
|
||||
self.reset_fixpoint_counters()
|
Reference in New Issue
Block a user