smaller task train

Steffen Illium
2022-03-03 14:57:26 +01:00
parent 926b27b4ef
commit e167cc78c5
4 changed files with 490 additions and 170 deletions

View File

@@ -1,5 +1,6 @@
 import pickle
 import re
+import shutil
 from collections import defaultdict
 from pathlib import Path
 import sys
@@ -45,11 +46,14 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 100
+EPOCH = 50
 VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DATA_PATH = Path('data')
+DATA_PATH.mkdir(exist_ok=True, parents=True)
 if debug:
 torch.autograd.set_detect_anomaly(True)
@@ -86,6 +90,9 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
 ckpt_path.parent.mkdir(exist_ok=True, parents=True)
 torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+py_store_path = Path(out_path) / 'exp_py.txt'
+if not py_store_path.exists():
+shutil.copy(__file__, py_store_path)
 return ckpt_path
@@ -98,9 +105,9 @@ def validate(checkpoint_path, ratio=0.1):
 ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
 try:
-datas = MNIST(str(data_path), transform=ut, train=False)
+datas = MNIST(str(DATA_PATH), transform=ut, train=False)
 except RuntimeError:
-datas = MNIST(str(data_path), transform=ut, train=False, download=True)
+datas = MNIST(str(DATA_PATH), transform=ut, train=False, download=True)
 valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
 model = torch.load(checkpoint_path, map_location=DEVICE).eval()
@@ -171,13 +178,13 @@ def plot_training_result(path_to_dataframe):
 # plots the first set of data
 data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
-palette = sns.color_palette()[0:data.reset_index()['Metric'].unique().shape[0]]
+palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
 sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
 palette=palette, ax=ax1)
 # plots the second set of data
 data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
-palette = sns.color_palette()[len(palette):data.reset_index()['Metric'].unique().shape[0] + len(palette)]
+palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
 sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette)
 ax1.set(yscale='log', ylabel='Losses')
@@ -195,7 +202,7 @@ def plot_training_result(path_to_dataframe):
 def plot_network_connectivity_by_fixtype(path_to_trained_model):
-m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
+m = torch.load(path_to_trained_model, map_location=torch.device('cpu')).eval()
 # noinspection PyProtectedMember
 particles = list(m.particles)
 df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
@@ -209,8 +216,9 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
 divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
 df.loc[(df['layer'] == layer), 'neuron'] /= divisor
-print('gathered')
+tqdm.write(f'Connectivity Data gathered')
-for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
+for n, fixtype in enumerate(ft.all_types()):
+if df[df['type'] == fixtype].shape[0] > 0:
 plt.clf()
 ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
 legend=False, estimator=None, lw=1)
@@ -223,17 +231,22 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
 plt.show()
 else:
 plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
-print('plottet')
+tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0]}')
+else:
+tqdm.write(f'No Connectivity {fixtype}')
-def run_particle_dropout_test(run_path):
+def run_particle_dropout_test(model_path):
-diff_store_path = run_path / 'diff_store.csv'
+diff_store_path = model_path.parent / 'diff_store.csv'
+latest_model = torch.load(model_path, map_location=DEVICE).eval()
 prtcl_dict = defaultdict(lambda: 0)
 _ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
 tqdm.write(str(dict(prtcl_dict)))
+diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
 acc_pre = validate(model_path, ratio=1).item()
-diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
+diff_df.loc[diff_df.shape[0]] = ('All Organism', acc_pre, 0)
 for fixpoint_type in ft.all_types():
 new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
 if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
@@ -247,14 +260,16 @@ def run_particle_dropout_test(run_path):
 return diff_store_path
-def plot_dropout_stacked_barplot(path_to_diff_df):
+def plot_dropout_stacked_barplot(model_path):
-diff_df = pd.read_csv(path_to_diff_df)
+diff_store_path = model_path.parent / 'diff_store.csv'
+diff_df = pd.read_csv(diff_store_path)
 particle_dict = defaultdict(lambda: 0)
+latest_model = torch.load(model_path, map_location=DEVICE).eval()
 _ = test_for_fixpoints(particle_dict, list(latest_model.particles))
 tqdm.write(str(dict(particle_dict)))
 plt.clf()
 fig, ax = plt.subplots(ncols=2)
-colors = sns.color_palette()[:diff_df.shape[0]]
+colors = sns.color_palette()[1:diff_df.shape[0]+1]
 barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
 # noinspection PyUnboundLocalVariable
 #for idx, patch in enumerate(barplot.patches):
@@ -265,18 +280,18 @@ def plot_dropout_stacked_barplot(path_to_diff_df):
 ax[0].set_title('Accuracy after particle dropout')
 ax[0].set_xlabel('Particle Type')
-ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
+ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)), )
 ax[1].set_title('Particle Count')
 plt.tight_layout()
 if debug:
 plt.show()
 else:
-plt.savefig(Path(path_to_diff_df.parent / 'dropout_stacked_barplot.png'), dpi=300)
+plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
-def run_particle_dropout_and_plot(run_path):
+def run_particle_dropout_and_plot(model_path):
-diff_store_path = run_particle_dropout_test(run_path)
+diff_store_path = run_particle_dropout_test(model_path)
 plot_dropout_stacked_barplot(diff_store_path)
@@ -284,40 +299,63 @@ def flat_for_store(parameters):
 return (x.item() for y in parameters for x in y.detach().flatten())
+def train_self_replication(model, optimizer, st_steps) -> dict:
+for _ in range(st_steps):
+self_train_loss = model.combined_self_train(optimizer)
+# noinspection PyUnboundLocalVariable
+step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
+return step_log
+def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
+# Zero your gradients for every batch!
+optimizer.zero_grad()
+btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
+y_prd = model(btch_x)
+# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+loss = loss_func(y_prd, btch_y.to(torch.float))
+loss.backward()
+# Adjust learning weights
+optimizer.step()
+stp_log = dict(Metric='Task Loss', Score=loss.item())
+return stp_log, y_prd
 if __name__ == '__main__':
-self_train = True
 training = True
-train_to_id_first = False
+train_to_id_first = True
 train_to_task_first = False
-sequential_task_train = True
+seq_task_train = True
-force_st_for_n_from_last_epochs = 5
+force_st_for_epochs_n = 5
-n_st_per_batch = 10
+n_st_per_batch = 2
-# activation = None # nn.ReLU()
+activation = None # nn.ReLU()
-use_sparse_network = True
+use_sparse_network = False
 for weight_hidden_size in [4, 5, 6]:
 tsk_threshold = 0.85
 weight_hidden_size = weight_hidden_size
-residual_skip = True
+residual_skip = False
-n_seeds = 1
+n_seeds = 3
+depth = 3
-data_path = Path('data')
-data_path.mkdir(exist_ok=True, parents=True)
+assert not (train_to_task_first and train_to_id_first)
-st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
+ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
-# ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
 res_str = f'{"" if residual_skip else "_no_res"}'
 # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
 id_str = f'{f"_StToId" if train_to_id_first else ""}'
 tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
 sprs_str = '_sprs' if use_sparse_network else ''
-f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
+f_str = f'_f_{force_st_for_epochs_n}' if \
-force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else ""
+force_st_for_epochs_n and seq_task_train and train_to_task_first else ""
 config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
-exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
+exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
 if not training:
 # noinspection PyRedeclaration
@@ -325,12 +363,10 @@ if __name__ == '__main__':
 for seed in range(n_seeds):
 seed_path = exp_path / str(seed)
-seed_path.mkdir(exist_ok=True, parents=True)
 model_path = seed_path / '0000_trained_model.zip'
 df_store_path = seed_path / 'train_store.csv'
 weight_store_path = seed_path / 'weight_store.csv'
-init_st_store_path = seed_path / 'init_st_counter.csv'
 srnn_parameters = dict()
 if training:
@@ -340,135 +376,84 @@ if __name__ == '__main__':
 utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
 try:
-dataset = MNIST(str(data_path), transform=utility_transforms)
+dataset = MNIST(str(DATA_PATH), transform=utility_transforms)
 except RuntimeError:
-dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
+dataset = MNIST(str(DATA_PATH), transform=utility_transforms, download=True)
 d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
 interface = np.prod(dataset[0][0].shape)
-dense_metanet = MetaNet(interface, depth=3, width=6, out=10, residual_skip=residual_skip,
+dense_metanet = MetaNet(interface, depth=depth, width=6, out=10, residual_skip=residual_skip,
-weight_hidden_size=weight_hidden_size
+weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE)
-).to(DEVICE)
+sparse_metanet = SparseNetwork(interface, depth=depth, width=6, out=10, residual_skip=residual_skip,
-sparse_metanet = SparseNetwork(interface, depth=3, width=6, out=10, residual_skip=residual_skip,
+weight_hidden_size=weight_hidden_size, activation=activation
-weight_hidden_size=weight_hidden_size
 ).to(DEVICE) if use_sparse_network else dense_metanet
-meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
+if use_sparse_network:
+sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
 loss_fn = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9)
+dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
+sparse_optimizer = torch.optim.SGD(
+sparse_metanet.parameters(), lr=0.001, momentum=0.9
+) if use_sparse_network else dense_optimizer
+dense_weights_updated = False
+sparse_weights_updated = False
 train_store = new_storage_df('train', None)
-weight_store = new_storage_df('weights', meta_weight_count)
+weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)
-if train_to_task_first:
+init_tsk = train_to_task_first
-dense_metanet = dense_metanet.train()
+for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
-for epoch in trange(10):
+is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'):
+is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
-# Task Train
-# Zero your gradients for every batch!
-optimizer.zero_grad()
-batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-y_pred = dense_metanet(batch_x)
-loss = loss_fn(y_pred, batch_y.to(torch.long))
-loss.backward()
-# Adjust learning weights
-optimizer.step()
-step_log = dict(Epoch=epoch, Batch=batch,
-Metric='Task Loss', Score=loss.item())
-train_store.loc[train_store.shape[0]] = step_log
-# Transfer weights
-if use_sparse_network:
-sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
-if train_to_id_first:
-sparse_metanet = sparse_metanet.train()
-init_st_epochs = 1500
-init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
-for st_epoch in trange(init_st_epochs):
-_ = sparse_metanet.combined_self_train(optimizer)
-if st_epoch % 500 == 0:
-counter = defaultdict(lambda: 0)
-id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
-counter = dict(counter)
-tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}")
-for key, value in counter.items():
-init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
-sparse_metanet.reset_diverged_particles()
-counter = defaultdict(lambda: 0)
-id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
-counter = dict(counter)
-tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
-for key, value in counter.items():
-init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
-init_st_df.to_csv(init_st_store_path, mode='w', index=False)
-c = pd.read_csv(init_st_store_path)
-sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
-plt.savefig(init_st_store_path.parent / f'{init_st_store_path.stem}.png', dpi=300)
-# Transfer weights
-if use_sparse_network:
-dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
-for epoch in trange(EPOCH, desc=f'Train - Epochs'):
-tqdm.write(f'{seed}: {exp_path}')
-is_validation_epoch = (epoch % VALIDATION_FRQ == 0) if not debug else True
-is_self_train_epoch = (epoch % SELF_TRAIN_FRQ == 0) if not debug else True
 sparse_metanet = sparse_metanet.train()
 dense_metanet = dense_metanet.train()
-if is_validation_epoch:
+# Init metrics, even we do not need:
 metric = torchmetrics.Accuracy()
-else:
-metric = None
-for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'):
+# Define what to train in this epoch:
+do_tsk_train = train_to_task_first
+force_st = (force_st_for_epochs_n >= (EPOCH - epoch)) and force_st_for_epochs_n
+init_st = (train_to_id_first and not dense_metanet.count_fixpoints() > 200)
+do_st_train = init_st or is_self_train_epoch or force_st
+for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
 # Self Train
-if is_self_train_epoch:
+if do_st_train:
-for _ in range(n_st_per_batch):
-self_train_loss = sparse_metanet.combined_self_train(optimizer)
-# noinspection PyUnboundLocalVariable
-step_log = dict(Epoch=epoch, Batch=batch,
-Metric='Self Train Loss', Score=self_train_loss.item())
-train_store.loc[train_store.shape[0]] = step_log
-# Clean Divergent
-sparse_metanet.reset_diverged_particles()
 # Transfer weights
+if dense_weights_updated:
+sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+dense_weights_updated = False
+st_steps = n_st_per_batch if not init_st else n_st_per_batch * 10
+step_log = train_self_replication(sparse_metanet, sparse_optimizer, st_steps)
+step_log.update(dict(Epoch=epoch, Batch=batch))
+train_store.loc[train_store.shape[0]] = step_log
 if use_sparse_network:
-dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+sparse_weights_updated = True
 # Task Train
-# Zero your gradients for every batch!
+if not init_st:
-optimizer.zero_grad()
-batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-y_pred = dense_metanet(batch_x)
-loss = loss_fn(y_pred, batch_y.to(torch.long))
-loss.backward()
-# Adjust learning weights
-optimizer.step()
 # Transfer weights
-if use_sparse_network:
+if sparse_weights_updated:
-sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+sparse_weights_updated = False
+step_log, y_pred = train_task(dense_metanet, dense_optimizer, loss_fn, batch_x, batch_y)
-step_log = dict(Epoch=epoch, Batch=batch,
+step_log.update(dict(Epoch=epoch, Batch=batch))
-Metric='Task Loss', Score=loss.item())
 train_store.loc[train_store.shape[0]] = step_log
-if is_validation_epoch:
+if use_sparse_network:
+dense_weights_updated = True
 metric(y_pred.cpu(), batch_y.cpu())
-if batch >= 3 and debug:
-break
 if is_validation_epoch:
-dense_metanet = dense_metanet.eval()
+if sparse_weights_updated:
+dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+sparse_weights_updated = False
+dense_metanet = dense_metanet.eval()
+if do_tsk_train:
 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
 Metric='Train Accuracy', Score=metric.compute().item())
 train_store.loc[train_store.shape[0]] = validation_log
@@ -477,8 +462,12 @@ if __name__ == '__main__':
 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
 Metric='Test Accuracy', Score=accuracy)
 train_store.loc[train_store.shape[0]] = validation_log
+if init_tsk or (train_to_task_first and seq_task_train):
-if is_validation_epoch:
+init_tsk = accuracy <= tsk_threshold
+if init_st or is_validation_epoch:
+if dense_weights_updated:
+sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+dense_weights_updated = False
 counter_dict = defaultdict(lambda: 0)
 # This returns ID-functions
 _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
@@ -487,18 +476,26 @@ if __name__ == '__main__':
 step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
 train_store.loc[train_store.shape[0]] = step_log
 tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
+if sum(x.is_fixpoint == ft.identity_func for x in dense_metanet.particles) > 200:
+train_to_id_first = False
+# Reset Diverged particles
+sparse_metanet.reset_diverged_particles()
+if use_sparse_network:
+sparse_weights_updated = True
+# FLUSH to disk
+if is_validation_epoch:
 for particle in dense_metanet.particles:
 weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
 weight_store.loc[weight_store.shape[0]] = weight_log
-train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(),
+train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-index=False)
+weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(),
-index=False)
 train_store = new_storage_df('train', None)
-weight_store = new_storage_df('weights', meta_weight_count)
+weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)
-dense_metanet.eval()
+###########################################################
+# EPOCHS endet
+dense_metanet = dense_metanet.eval()
 counter_dict = defaultdict(lambda: 0)
 # This returns ID-functions
@@ -527,7 +524,7 @@ if __name__ == '__main__':
 print(f'Search path was: {seed_path}:')
 print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
 exit(1)
-latest_model = torch.load(model_path, map_location=DEVICE).eval()
 try:
 run_particle_dropout_and_plot(seed_path)
 except ValueError as e:
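
The refactored training loop above keeps two coupled models (the dense MetaNet and an optional SparseNetwork) and only copies weights across when the dense_weights_updated / sparse_weights_updated flags say the other side is stale. Below is a minimal, self-contained sketch of that dirty-flag sync pattern; DenseSide/SparseSide here are hypothetical stand-ins, only the flag logic mirrors the diff.

# Minimal sketch of the lazy "dirty flag" weight sync used in the new training loop.
# The Side class is a hypothetical stand-in for MetaNet / SparseNetwork.
class Side:
    def __init__(self, name):
        self.name = name
        self.weights = 0

    def pull_from(self, other):
        # stands in for replace_weights_by_particles / replace_particles
        self.weights = other.weights


dense, sparse = Side('dense'), Side('sparse')
dense_weights_updated = False
sparse_weights_updated = False

for step in range(3):
    # self-train step runs on the sparse side
    if dense_weights_updated:       # sync only if the dense side actually changed
        sparse.pull_from(dense)
        dense_weights_updated = False
    sparse.weights += 1             # pretend self-train updated the particles
    sparse_weights_updated = True

    # task-train step runs on the dense side
    if sparse_weights_updated:
        dense.pull_from(sparse)
        sparse_weights_updated = False
    dense.weights += 1              # pretend task training updated the weights
    dense_weights_updated = True
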

View File

@@ -0,0 +1,316 @@
import platform
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import torchmetrics
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# noinspection DuplicatedCode
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@ Warning, Debugging Config@!!!!!! @")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
debug = False
try:
# noinspection PyUnboundLocalVariable
if __package__ is None:
DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(DIR.parent))
__package__ = DIR.name
else:
DIR = None
except NameError:
DIR = None
pass
from network import MetaNet, FixTypes as ft
from sparse_net import SparseNetwork
from functionalities_test import test_for_fixpoints
from experiments.meta_task_exp import new_storage_df, train_self_replication, train_task, set_checkpoint, \
flat_for_store, plot_training_result, plot_training_particle_types, run_particle_dropout_and_plot, \
plot_network_connectivity_by_fixtype
WORKER = 10 if not debug else 2
debug = False
BATCHSIZE = 50 if not debug else 50
EPOCH = 10
VALIDATION_FRQ = 1 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class AddTaskDataset(Dataset):
def __init__(self, length=int(1e5)):
super().__init__()
self.length = length
def __len__(self):
return self.length
def __getitem__(self, _):
ab = torch.randn(size=(2,)).to(torch.float32)
return ab, ab.sum(axis=-1, keepdims=True)
def validate(checkpoint_path, valid_d, ratio=1, validmetric=torchmetrics.MeanAbsoluteError()):
checkpoint_path = Path(checkpoint_path)
import torchmetrics
# initialize metric
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
n_samples = int(len(valid_d) * ratio)
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
y_valid = model(valid_batch_x)
# metric on current batch
acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
pbar.set_postfix_str(f'Acc: {acc}')
pbar.update()
if idx == n_samples:
break
# metric on all batches using custom accumulation
acc = validmetric.compute()
tqdm.write(f"Avg. Accuracy on all data: {acc}")
return acc
def checkpoint_and_validate(model, out_path, epoch_n, valid_d, final_model=False):
out_path = Path(out_path)
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
result = validate(ckpt_path, valid_d)
return result
if __name__ == '__main__':
training = True
train_to_id_first = False
train_to_task_first = False
seq_task_train = True
force_st_for_epochs_n = 5
n_st_per_batch = 10
activation = None # nn.ReLU()
use_sparse_network = False
for weight_hidden_size in [3, 4]:
tsk_threshold = 0.85
weight_hidden_size = weight_hidden_size
residual_skip = False
n_seeds = 3
depth = 3
width = 3
out = 1
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
assert not (train_to_task_first and train_to_id_first)
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
s_str = f'_n_{n_st_per_batch}' if n_st_per_batch > 1 else ""
res_str = f'{"" if residual_skip else "_no_res"}'
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
id_str = f'{f"_StToId" if train_to_id_first else ""}'
tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
sprs_str = '_sprs' if use_sparse_network else ''
f_str = f'_f_{force_st_for_epochs_n}' if \
force_st_for_epochs_n and seq_task_train and train_to_task_first else ""
config_str = f'{s_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
exp_path = Path('output') / f'add_st_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
if not training:
# noinspection PyRedeclaration
exp_path = Path('output') / 'mn_st_n_2_100_4'
for seed in range(n_seeds):
seed_path = exp_path / str(seed)
model_path = seed_path / '0000_trained_model.zip'
df_store_path = seed_path / 'train_store.csv'
weight_store_path = seed_path / 'weight_store.csv'
srnn_parameters = dict()
if training:
# Check if files do exist on project location, warn and break.
for path in [model_path, df_store_path, weight_store_path]:
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
train_data = AddTaskDataset()
valid_data = AddTaskDataset()
train_load = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True,
drop_last=True, num_workers=WORKER)
vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=False,
drop_last=True, num_workers=WORKER)
interface = np.prod(train_data[0][0].shape)
dense_metanet = MetaNet(interface, depth=depth, width=width, out=out,
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
activation=activation
).to(DEVICE)
sparse_metanet = SparseNetwork(interface, depth=depth, width=width, out=out,
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
activation=activation
).to(DEVICE) if use_sparse_network else dense_metanet
if use_sparse_network:
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
loss_fn = nn.MSELoss()
dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
sparse_optimizer = torch.optim.SGD(
sparse_metanet.parameters(), lr=0.001, momentum=0.9
) if use_sparse_network else dense_optimizer
dense_weights_updated = False
sparse_weights_updated = False
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)
init_tsk = train_to_task_first
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
sparse_metanet = sparse_metanet.train()
dense_metanet = dense_metanet.train()
# Init metrics, even we do not need:
metric = torchmetrics.MeanAbsoluteError()
# Define what to train in this epoch:
do_tsk_train = train_to_task_first
force_st = (force_st_for_epochs_n >= (EPOCH - epoch)) and force_st_for_epochs_n
init_st = (train_to_id_first and not dense_metanet.count_fixpoints() > 200)
do_st_train = init_st or is_self_train_epoch or force_st
for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
total=len(train_load), desc='MetaNet Train - Batch'
):
# Self Train
if do_st_train:
# Transfer weights
if dense_weights_updated:
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
dense_weights_updated = False
st_steps = n_st_per_batch if not init_st else n_st_per_batch * 10
step_log = train_self_replication(sparse_metanet, sparse_optimizer, st_steps)
step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = step_log
if use_sparse_network:
sparse_weights_updated = True
# Task Train
if not init_st:
# Transfer weights
if sparse_weights_updated:
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
sparse_weights_updated = False
step_log, y_pred = train_task(dense_metanet, dense_optimizer, loss_fn, batch_x, batch_y)
step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = step_log
if use_sparse_network:
dense_weights_updated = True
metric(y_pred.cpu(), batch_y.cpu())
if is_validation_epoch:
if sparse_weights_updated:
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
sparse_weights_updated = False
dense_metanet = dense_metanet.eval()
if do_tsk_train:
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Train Accuracy', Score=metric.compute().item())
train_store.loc[train_store.shape[0]] = validation_log
accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch, vali_load).item()
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy)
train_store.loc[train_store.shape[0]] = validation_log
if init_tsk or (train_to_task_first and seq_task_train):
init_tsk = accuracy <= tsk_threshold
if init_st or is_validation_epoch:
if dense_weights_updated:
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
dense_weights_updated = False
counter_dict = defaultdict(lambda: 0)
# This returns ID-functions
_ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
counter_dict = dict(counter_dict)
for key, value in counter_dict.items():
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
if sum(x.is_fixpoint == ft.identity_func for x in dense_metanet.particles) > 200:
train_to_id_first = False
# Reset Diverged particles
sparse_metanet.reset_diverged_particles()
if use_sparse_network:
sparse_weights_updated = True
# FLUSH to disk
if is_validation_epoch:
for particle in dense_metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)
###########################################################
# EPOCHS endet
dense_metanet = dense_metanet.eval()
counter_dict = defaultdict(lambda: 0)
# This returns ID-functions
_ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
for key, value in dict(counter_dict).items():
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, vali_load, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
for particle in dense_metanet.particles:
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
plot_training_result(df_store_path)
plot_training_particle_types(df_store_path)
try:
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
except StopIteration:
print('Model pattern did not trigger.')
print(f'Search path was: {seed_path}:')
print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
exit(1)
try:
run_particle_dropout_and_plot(model_path)
except ValueError as e:
print(e)
try:
plot_network_connectivity_by_fixtype(model_path)
except ValueError as e:
print(e)
if n_seeds >= 2:
pass
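
For orientation, the new script above trains on an "add two numbers" regression task and scores it with mean absolute error instead of classification accuracy. The following self-contained sketch reproduces only that data/metric setup with standard torch calls; ToyAddDataset and the nn.Linear model are hypothetical stand-ins for AddTaskDataset and the MetaNet.

import torch
import torchmetrics
from torch import nn
from torch.utils.data import Dataset, DataLoader


class ToyAddDataset(Dataset):           # stand-in for AddTaskDataset
    def __init__(self, length=1000):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        ab = torch.randn(2)
        return ab, ab.sum(dim=-1, keepdim=True)   # target = a + b


loader = DataLoader(ToyAddDataset(), batch_size=50, drop_last=True)
model = nn.Linear(2, 1)                  # stand-in task model
metric = torchmetrics.MeanAbsoluteError()

with torch.no_grad():
    for x, y in loader:
        metric(model(x), y)              # accumulate batch-wise MAE
print(f'MAE of the untrained stand-in model: {metric.compute().item():.3f}')
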

View File

@@ -57,7 +57,7 @@ if __name__ == '__main__':
 multiplication_target = 0.03
 loss_fn = nn.MSELoss()
-optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
+optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
 train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
@@ -67,7 +67,7 @@ if __name__ == '__main__':
 mean_batch_loss = []
 mean_self_tain_loss = []
 for batch, (batch_x, batch_y) in tenumerate(dataloader):
-self_train_loss, _ = net.self_train(10, save_history=False)
+self_train_loss, _ = net.self_train(2, save_history=False)
 batch_x_emb = torch.zeros(batch_x.shape[0], 5)
 batch_x_emb[:, -1] = batch_x.squeeze()
 y = net(batch_x_emb)

View File

@@ -489,6 +489,13 @@ class MetaNet(nn.Module):
 def all_layers(self):
 return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
+@property
+def particle_parameter_count(self):
+return sum(p.numel() for p in next(self.particles).parameters())
+def count_fixpoints(self, fix_type=FixTypes.identity_func):
+return sum(x.is_fixpoint == fix_type for x in self.particles)
 def reset_diverged_particles(self):
 for particle in self.particles:
 if particle.is_fixpoint == FixTypes.divergent:
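
The two helpers added to MetaNet above replace the ad-hoc meta_weight_count computation and the manual fixpoint-sum in the training scripts. A rough, self-contained illustration of what they compute; FakeParticle and FakeMetaNet are hypothetical stand-ins, not the repository's classes.

import torch
from torch import nn


class FakeParticle(nn.Linear):
    def __init__(self, is_fixpoint='identity_func'):
        super().__init__(4, 4)
        self.is_fixpoint = is_fixpoint


class FakeMetaNet:
    def __init__(self, n=5):
        self._particles = [FakeParticle('identity_func' if i % 2 else 'other_func')
                           for i in range(n)]

    @property
    def particles(self):
        return iter(self._particles)

    @property
    def particle_parameter_count(self):
        # same idea as the new property: parameter count of a single particle network
        return sum(p.numel() for p in next(self.particles).parameters())

    def count_fixpoints(self, fix_type='identity_func'):
        # same idea as the new method: how many particles sit at a given fixpoint type
        return sum(x.is_fixpoint == fix_type for x in self.particles)


net = FakeMetaNet()
print(net.particle_parameter_count)   # 4*4 weights + 4 biases = 20
print(net.count_fixpoints())          # 2 of the 5 fake particles are marked identity
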