in between upload

This commit is contained in:
Steffen Illium
2022-02-27 17:56:25 +01:00
parent 78a919395b
commit 926b27b4ef
3 changed files with 167 additions and 104 deletions

View File

@@ -16,7 +16,7 @@ from torch.nn import Flatten
 from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
-from tqdm import tqdm
+from tqdm import tqdm, trange

 # noinspection DuplicatedCode
 if platform.node() == 'CarbonX':
@@ -46,7 +46,7 @@ WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
 EPOCH = 100
-VALIDATION_FRQ = 4 if not debug else 1
+VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -292,24 +292,23 @@ if __name__ == '__main__':
     train_to_task_first = False
     sequential_task_train = True
     force_st_for_n_from_last_epochs = 5
-    n_st_per_batch = 3
-    activation = None  # nn.ReLU()
-    use_sparse_network = False
+    n_st_per_batch = 10
+    # activation = None  # nn.ReLU()
+    use_sparse_network = True

-    for weight_hidden_size in [8]:
+    for weight_hidden_size in [4, 5, 6]:
         tsk_threshold = 0.85
         weight_hidden_size = weight_hidden_size
         residual_skip = True
-        n_seeds = 3
+        n_seeds = 1
         data_path = Path('data')
         data_path.mkdir(exist_ok=True, parents=True)
-        assert not (train_to_task_first and train_to_id_first)

         st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
-        ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
+        # ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
         res_str = f'{"" if residual_skip else "_no_res"}'
         # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
         id_str = f'{f"_StToId" if train_to_id_first else ""}'
@@ -318,7 +317,7 @@ if __name__ == '__main__':
         f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
             force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else ""
         config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
-        exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
+        exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'

         if not training:
             # noinspection PyRedeclaration
@@ -326,10 +325,12 @@ if __name__ == '__main__':
         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)
+            seed_path.mkdir(exist_ok=True, parents=True)
             model_path = seed_path / '0000_trained_model.zip'
             df_store_path = seed_path / 'train_store.csv'
             weight_store_path = seed_path / 'weight_store.csv'
+            init_st_store_path = seed_path / 'init_st_counter.csv'
             srnn_parameters = dict()

             if training:
@@ -345,92 +346,139 @@ if __name__ == '__main__':
                 d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
                 interface = np.prod(dataset[0][0].shape)
-                dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                        weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE)
-                sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                               weight_hidden_size=weight_hidden_size, activation=activation
+                dense_metanet = MetaNet(interface, depth=3, width=6, out=10, residual_skip=residual_skip,
+                                        weight_hidden_size=weight_hidden_size
+                                        ).to(DEVICE)
+                sparse_metanet = SparseNetwork(interface, depth=3, width=6, out=10, residual_skip=residual_skip,
+                                               weight_hidden_size=weight_hidden_size
                                                ).to(DEVICE) if use_sparse_network else dense_metanet
                 meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())

                 loss_fn = nn.CrossEntropyLoss()
-                dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
-                sparse_optimizer = torch.optim.SGD(
-                    sparse_metanet.parameters(), lr=0.004, momentum=0.9
-                ) if use_sparse_network else dense_optimizer
+                optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9)

                 train_store = new_storage_df('train', None)
                 weight_store = new_storage_df('weights', meta_weight_count)
-                init_tsk = train_to_task_first
-                for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
+
+                if train_to_task_first:
+                    dense_metanet = dense_metanet.train()
+                    for epoch in trange(10):
+                        for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'):
+                            # Task Train
+                            # Zero your gradients for every batch!
+                            optimizer.zero_grad()
+                            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                            y_pred = dense_metanet(batch_x)
+                            loss = loss_fn(y_pred, batch_y.to(torch.long))
+                            loss.backward()
+                            # Adjust learning weights
+                            optimizer.step()
+                            step_log = dict(Epoch=epoch, Batch=batch,
+                                            Metric='Task Loss', Score=loss.item())
+                            train_store.loc[train_store.shape[0]] = step_log
+                    # Transfer weights
+                    if use_sparse_network:
+                        sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+
+                if train_to_id_first:
+                    sparse_metanet = sparse_metanet.train()
+                    init_st_epochs = 1500
+                    init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
+                    for st_epoch in trange(init_st_epochs):
+                        _ = sparse_metanet.combined_self_train(optimizer)
+                        if st_epoch % 500 == 0:
+                            counter = defaultdict(lambda: 0)
+                            id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
+                            counter = dict(counter)
+                            tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}")
+                            for key, value in counter.items():
+                                init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
+                            sparse_metanet.reset_diverged_particles()
+                    counter = defaultdict(lambda: 0)
+                    id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
+                    counter = dict(counter)
+                    tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
+                    for key, value in counter.items():
+                        init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
+                    init_st_df.to_csv(init_st_store_path, mode='w', index=False)
+                    c = pd.read_csv(init_st_store_path)
+                    sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
+                    plt.savefig(init_st_store_path.parent / f'{init_st_store_path.stem}.png', dpi=300)
+                    # Transfer weights
+                    if use_sparse_network:
+                        dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+
+                for epoch in trange(EPOCH, desc=f'Train - Epochs'):
                     tqdm.write(f'{seed}: {exp_path}')
-                    is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-                    is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
+                    is_validation_epoch = (epoch % VALIDATION_FRQ == 0) if not debug else True
+                    is_self_train_epoch = (epoch % SELF_TRAIN_FRQ == 0) if not debug else True
                     sparse_metanet = sparse_metanet.train()
                     dense_metanet = dense_metanet.train()
                     if is_validation_epoch:
                         metric = torchmetrics.Accuracy()
                     else:
                         metric = None
-                    init_st = train_to_id_first and not all(
-                        x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
-                    )
-                    force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
-                                ) and sequential_task_train and force_st_for_n_from_last_epochs
-                    for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
+
+                    for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'):
                         # Self Train
-                        if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
-                            # Transfer weights
-                            if use_sparse_network:
-                                sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+                        if is_self_train_epoch:
                             for _ in range(n_st_per_batch):
-                                self_train_loss = sparse_metanet.combined_self_train(sparse_optimizer, reduction='mean')
+                                self_train_loss = sparse_metanet.combined_self_train(optimizer)
                                 # noinspection PyUnboundLocalVariable
                                 step_log = dict(Epoch=epoch, Batch=batch,
                                                 Metric='Self Train Loss', Score=self_train_loss.item())
                                 train_store.loc[train_store.shape[0]] = step_log
-                            # Clean Divergent
-                            sparse_metanet.reset_diverged_particles()
                             # Transfer weights
                             if use_sparse_network:
                                 dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+                            dense_metanet.reset_diverged_particles()

                         # Task Train
-                        if not init_st:
-                            # Zero your gradients for every batch!
-                            dense_optimizer.zero_grad()
-                            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                            y_pred = dense_metanet(batch_x)
-                            # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                            loss = loss_fn(y_pred, batch_y.to(torch.long))
-                            loss.backward()
-                            # Adjust learning weights
-                            dense_optimizer.step()
-                            step_log = dict(Epoch=epoch, Batch=batch,
-                                            Metric='Task Loss', Score=loss.item())
-                            train_store.loc[train_store.shape[0]] = step_log
-                            if is_validation_epoch:
-                                metric(y_pred.cpu(), batch_y.cpu())
+                        # Zero your gradients for every batch!
+                        optimizer.zero_grad()
+                        batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                        y_pred = dense_metanet(batch_x)
+                        loss = loss_fn(y_pred, batch_y.to(torch.long))
+                        loss.backward()
+                        # Adjust learning weights
+                        optimizer.step()
+                        # Transfer weights
+                        if use_sparse_network:
+                            sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+                        step_log = dict(Epoch=epoch, Batch=batch,
+                                        Metric='Task Loss', Score=loss.item())
+                        train_store.loc[train_store.shape[0]] = step_log
+                        if is_validation_epoch:
+                            metric(y_pred.cpu(), batch_y.cpu())
                         if batch >= 3 and debug:
                             break

                     if is_validation_epoch:
                         dense_metanet = dense_metanet.eval()
-                        if not init_st:
-                            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                                  Metric='Train Accuracy', Score=metric.compute().item())
-                            train_store.loc[train_store.shape[0]] = validation_log
+                        validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                              Metric='Train Accuracy', Score=metric.compute().item())
+                        train_store.loc[train_store.shape[0]] = validation_log
                         accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
                         validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                               Metric='Test Accuracy', Score=accuracy)
                         train_store.loc[train_store.shape[0]] = validation_log
-                        if init_tsk or (train_to_task_first and sequential_task_train):
-                            init_tsk = accuracy <= tsk_threshold
-                    if init_st or is_validation_epoch:
+
+                    if is_validation_epoch:
                         counter_dict = defaultdict(lambda: 0)
                         # This returns ID-functions
                         _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
@@ -439,12 +487,14 @@ if __name__ == '__main__':
                             step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
                             train_store.loc[train_store.shape[0]] = step_log
                         tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
-                    if init_st or is_validation_epoch:
                         for particle in dense_metanet.particles:
                             weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                             weight_store.loc[weight_store.shape[0]] = weight_log
-                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(),
+                                   index=False)
+                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(),
+                                    index=False)
                 train_store = new_storage_df('train', None)
                 weight_store = new_storage_df('weights', meta_weight_count)
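For orientation, the per-batch flow this file now converges on is: optionally self-train the sparse network, copy its particle weights into the dense network, task-train the dense network, then mirror the updated weights back into the sparse twin. A minimal, self-contained sketch of that alternation; the two nn.Linear stand-ins, the separate optimizers, and the weight-decay placeholder for combined_self_train are illustrative assumptions, not the repo's MetaNet/SparseNetwork API:

import torch
from torch import nn

torch.manual_seed(0)
dense = nn.Linear(16, 10)    # stands in for the dense MetaNet
sparse = nn.Linear(16, 10)   # stands in for SparseNetwork
task_opt = torch.optim.SGD(dense.parameters(), lr=0.004, momentum=0.9)
self_opt = torch.optim.SGD(sparse.parameters(), lr=0.004, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

for batch in range(3):
    batch_x, batch_y = torch.randn(8, 16), torch.randint(0, 10, (8,))

    # Self-train phase (placeholder objective: shrink weights toward zero;
    # the repo instead trains every particle toward an identity fixpoint).
    self_opt.zero_grad()
    st_loss = sum(p.pow(2).sum() for p in sparse.parameters())
    st_loss.backward()
    self_opt.step()
    dense.load_state_dict(sparse.state_dict())   # transfer weights sparse -> dense

    # Task-train phase on the dense side.
    task_opt.zero_grad()
    loss = loss_fn(dense(batch_x), batch_y)
    loss.backward()
    task_opt.step()
    sparse.load_state_dict(dense.state_dict())   # transfer weights dense -> sparse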

View File

@@ -445,10 +445,12 @@ class MetaNet(nn.Module):
         tensor = self._meta_layer_first(x)
         residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if idx % 2 == 1 and self.residual_skip:
+            # if idx % 2 == 1 and self.residual_skip:
+            if self.residual_skip:
                 residual = tensor
             tensor = meta_layer(tensor)
-            if idx % 2 == 0 and self.residual_skip:
+            # if idx % 2 == 0 and self.residual_skip:
+            if self.residual_skip:
                 tensor = tensor + residual
         tensor = self._meta_layer_last(tensor)
         return tensor
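The hunk above drops the alternating pattern (save the residual before odd-indexed layers, add it back after even ones) in favour of a skip around every meta layer. A minimal sketch of the two behaviours, with plain nn.Linear layers standing in for the meta layers:

import torch
from torch import nn

layers = nn.ModuleList(nn.Linear(6, 6) for _ in range(4))
x = torch.randn(2, 6)

# Old behaviour: one residual spans a pair of layers.
tensor, residual = x, None
for idx, layer in enumerate(layers, start=1):
    if idx % 2 == 1:
        residual = tensor
    tensor = layer(tensor)
    if idx % 2 == 0:
        tensor = tensor + residual

# New behaviour: every layer reduces to tensor = layer(tensor) + tensor.
tensor = x
for layer in layers:
    tensor = layer(tensor) + tensor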

View File

@@ -1,25 +1,29 @@
 from collections import defaultdict

 import pandas as pd
+from matplotlib import pyplot as plt
+import seaborn as sns
 from torch import nn

 import functionalities_test
 from network import Net
-from functionalities_test import is_identity_function
-from tqdm import tqdm,trange
+from functionalities_test import is_identity_function, test_for_fixpoints, epsilon_error_margin
+from tqdm import tqdm, trange
 import numpy as np
 from pathlib import Path
 import torch
 from torch.nn import Flatten
 from torch.utils.data import DataLoader
-import torch.nn.functional as F
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize


 def xavier_init(m):
     if isinstance(m, nn.Linear):
-        nn.init.xavier_uniform_(m.weight.data)
+        return nn.init.xavier_uniform_(m.weight.data)
+    if isinstance(m, torch.Tensor):
+        return nn.init.xavier_uniform_(m)


 class SparseLayer(nn.Module):
@@ -101,7 +105,9 @@ class SparseLayer(nn.Module):
         for weights in self.weights:
             if torch.isinf(weights).any() or torch.isnan(weights).any():
                 with torch.no_grad():
-                    xavier_init(weights)
+                    where_nan = torch.nan_to_num(weights, -99, -99, -99)
+                    mask = torch.where(where_nan == -99, 0, 1)
+                    weights[:] = (where_nan * mask + torch.randn_like(weights) * (1 - mask))[:]

     @property
     def particle_weights(self):
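The replaced call reinitialised the whole tensor via xavier_init as soon as any entry diverged; the new code redraws only the non-finite slots. A standalone demonstration of the sentinel-mask trick (one caveat, noted here as an observation: a legitimate weight exactly equal to -99 would also be redrawn, which a mask built from torch.isfinite would avoid):

import torch

weights = torch.tensor([0.5, float('nan'), -1.2, float('inf')])

# Map NaN/inf entries to a -99 sentinel, turn the sentinel into a 0/1 keep-mask,
# then fill the masked-out slots with fresh Gaussian noise.
where_nan = torch.nan_to_num(weights, -99, -99, -99)
mask = torch.where(where_nan == -99, 0, 1)
repaired = where_nan * mask + torch.randn_like(weights) * (1 - mask)
print(repaired)  # finite everywhere; only positions 1 and 3 were redrawn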
@@ -139,8 +145,9 @@ def test_sparse_layer():
     optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
     # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
     df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
+    train_iterations = 20000

-    for train_iteration in trange(20000):
+    for train_iteration in trange(train_iterations):
         optimizer.zero_grad()
         X, Y = net.get_self_train_inputs_and_targets()
         output = net(X)
@@ -163,12 +170,11 @@ def test_sparse_layer():
     counter = defaultdict(lambda: 0)
     id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
     counter = dict(counter)
-    tqdm.write(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}")
+    tqdm.write(f"identity_fn after {train_iterations} self-train epochs: {counter}")
     for key, value in counter.items():
-        df.loc[df.shape[0]] = (train_iteration, key, value)
+        df.loc[df.shape[0]] = (train_iterations, key, value)
     df.to_csv('counter.csv', mode='w')
-    import seaborn as sns
-    import matplotlib.pyplot as plt

     c = pd.read_csv('counter.csv', index_col=0)
     sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
     plt.savefig('counter.png', dpi=300)
@@ -191,6 +197,11 @@ def embed_vector(x, repeat_dim):
 class SparseNetwork(nn.Module):

+    @property
+    def nr_nets(self):
+        return sum(x.nr_nets for x in self.sparselayers)
+
     def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1
                  ):
@@ -216,16 +227,13 @@ class SparseNetwork(nn.Module):
         if self.activation:
             tensor = self.activation(tensor)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
-            # Sparse Layer pass
+            # if idx % 2 == 1 and self.residual_skip:
+            if self.residual_skip:
+                residual = tensor
             tensor = self.sparse_layer_forward(tensor, network_layer)
-            if self.activation:
-                tensor = self.activation(tensor)
-            if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor.clone()
-            if nl_idx % 2 == 1 and self.residual_skip:
-                # noinspection PyUnboundLocalVariable
-                tensor += residual
+            # if idx % 2 == 0 and self.residual_skip:
+            if self.residual_skip:
+                tensor = tensor + residual
         tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
         return tensor
@@ -282,7 +290,7 @@ class SparseNetwork(nn.Module):
             output = layer(x)
             # loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
-            loss = loss_fn(output, target_data) * 85
+            loss = loss_fn(output, target_data) * layer.nr_nets
             losses.append(loss.detach())
             loss.backward()
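Swapping the hard-coded 85 for layer.nr_nets ties the scaling to the layer's actual particle count. Assuming loss_fn is mean-reduced (an assumption; the hunk does not show its definition), scaling the global mean by the net count is equivalent to summing per-net mean losses, as this small check with nn.MSELoss illustrates:

import torch
from torch import nn

nr_nets, n_weights = 4, 6
output = torch.randn(nr_nets, n_weights)
target = torch.randn(nr_nets, n_weights)

loss_fn = nn.MSELoss()  # mean over all nr_nets * n_weights elements
scaled = loss_fn(output, target) * nr_nets
per_net_sum = sum(loss_fn(output[i], target[i]) for i in range(nr_nets))
assert torch.isclose(scaled, per_net_sum)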
@@ -311,39 +319,42 @@ def test_sparse_net():
     data_dim = np.prod(dataset[0][0].shape)
     metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
     batchx, batchy = next(iter(d))
-    metanet(batchx)
-    print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
+    out = metanet(batchx)
+    result = sum([torch.allclose(out[i], batchy[i], rtol=0, atol=epsilon_error_margin) for i in range(metanet.nr_nets)])
+    # print(f"identity_fn after {train_iteration+1} self-train iterations: {result} /{net.nr_nets}")


 def test_sparse_net_sef_train():
-    net = SparseNetwork(5, 5, 6, 10)
-    epochs = 10000
-    df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
-    for epoch in trange(epochs):
-        _ = net.combined_self_train(optimizer)
-        if epoch % 500 == 0:
+    sparse_metanet = SparseNetwork(15*15, 5, 6, 10).to('cuda')
+    init_st_store_path = Path('counter.csv')
+    optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9)
+    init_st_epochs = 10000
+    init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
+
+    for st_epoch in trange(init_st_epochs):
+        _ = sparse_metanet.combined_self_train(optimizer)
+        if st_epoch % 500 == 0:
             counter = defaultdict(lambda: 0)
-            id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
+            id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
             counter = dict(counter)
-            tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}")
+            tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}")
             for key, value in counter.items():
-                df.loc[df.shape[0]] = (epoch, key, value)
-            net.reset_diverged_particles()
+                init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
+            sparse_metanet.reset_diverged_particles()
     counter = defaultdict(lambda: 0)
-    id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
+    id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
     counter = dict(counter)
-    tqdm.write(f"identity_fn after {epochs} self-train epochs: {counter}")
+    tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
     for key, value in counter.items():
-        df.loc[df.shape[0]] = (epoch, key, value)
-    df.to_csv('counter.csv', mode='w')
-    import seaborn as sns
-    import matplotlib.pyplot as plt
-    c = pd.read_csv('counter.csv', index_col=0)
+        init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
+    init_st_df.to_csv(init_st_store_path, mode='w', index=False)
+
+    c = pd.read_csv(init_st_store_path)
     sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
-    plt.savefig('counter.png', dpi=300)
+    plt.savefig(init_st_store_path, dpi=300)


 def test_manual_for_loop():
@@ -353,7 +364,7 @@ def test_manual_for_loop():
     rounds = 1000
     for net in tqdm(nets):
-        optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
+        optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
         for i in range(rounds):
             optimizer.zero_grad()
             input_data = net.input_weight_matrix()