Final Experiments and Plot adjustments

This commit is contained in:
Steffen Illium 2022-03-12 11:39:28 +01:00
parent dd2458da4a
commit 0ba3994325
6 changed files with 214 additions and 79 deletions

View File

@ -65,8 +65,8 @@ class AddTaskDataset(Dataset):
def set_checkpoint(model, out_path, epoch_n, final_model=False):
epoch_n = str(epoch_n)
if not final_model:
epoch_n = str(epoch_n)
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
if isinstance(epoch_n, str):
@ -145,6 +145,7 @@ def plot_training_particle_types(path_to_dataframe):
labels=fix_types.tolist(), colors=PALETTE)
ax.set(ylabel='Particle Count', xlabel='Epoch')
ax.yaxis.get_major_locator().set_params(integer=True)
# ax.set_title('Particle Type Count')
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
@ -219,6 +220,9 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
legend=False, estimator=None, lw=1)
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
ax.set_title(fixtype)
ax.yaxis.get_major_locator().set_params(integer=True)
ax.xaxis.get_major_locator().set_params(integer=True)
ax.set_ylabel('Normalized Neuron Position (1/n)') # XAXIS Label
lines = ax.get_lines()
for line in lines:
line.set_color(PALETTE[n])
@ -273,7 +277,7 @@ def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchme
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
ax[0].set_title(f'{metric_name} after particle dropout')
ax[0].set_xlabel('Particle Type')
# ax[0].set_xlabel('Particle Type') # XAXIS Label
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
@ -345,9 +349,11 @@ def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
fig, axs = plt.subplots(1, 3)
for idx, image in enumerate([binary_image, real_image, mnist_mean]):
for idx, (image, title) in enumerate(zip([binary_image, real_image, mnist_mean],
["Particle Count", "Particle Value", "MNIST mean"])):
img = axs[idx].imshow(image.squeeze().detach().cpu())
img.axes.axis('off')
img.axes.set_title('Random Noise')
plt.tight_layout()
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)

View File

@ -45,10 +45,9 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
row_headers = []
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
'Time to convergence', 'Time as fixpoint'])
with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
with tqdm(total=(seeds * noise_levels * len(networks)), desc='Per Particle Robustness') as pbar:
for setting, fixpoint in enumerate(networks): # 1 / n
row_headers.append(fixpoint.name)
for seed in range(seeds): # n / 1

View File

@ -25,6 +25,7 @@ from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task,
from experiments.robustness_tester import test_robustness
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
@ -33,9 +34,11 @@ if platform.node() == 'CarbonX':
else:
debug = False
from network import MetaNet, FixTypes
from functionalities_test import test_for_fixpoints
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) # , AddGaussianNoise()])
WORKER = 10 if not debug else 2
debug = False
@ -62,17 +65,24 @@ if __name__ == '__main__':
training = True
plotting = True
robustnes = True
n_st = 1 # per batch !!
activation = None # nn.ReLU()
robustnes = False
n_st = 300 # per batch !!
activation = None # nn.ReLU()
for weight_hidden_size in [3]:
train_to_task_first = True
min_task_acc = 0.85
residual_skip = True
add_gauss = False
alpha_st_modulator = 0
for weight_hidden_size in [5]:
weight_hidden_size = weight_hidden_size
residual_skip = False
n_seeds = 3
depth = 5
width = 3
depth = 3
width = 5
out = 10
data_path = Path('data')
@ -80,13 +90,17 @@ if __name__ == '__main__':
# noinspection PyUnresolvedReferences
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
res_str = f'{"" if residual_skip else "_no_res"}'
a_str = f'_aStM_{alpha_st_modulator}' if alpha_st_modulator not in [1, 0] else ''
res_str = '_no_res' if not residual_skip else ''
st_str = f'_nst_{n_st}'
tsk_str = f'_tsktr_{min_task_acc}' if train_to_task_first else ''
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
config_str = f'{res_str}{ac_str}{st_str}'
exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}'
config_str = f'{res_str}{ac_str}{st_str}{tsk_str}{a_str}{w_str}'
exp_path = Path('output') / f'mn_st_{EPOCH}{config_str}'
last_accuracy = 0
for seed in range(n_seeds):
for seed in range(0, n_seeds):
seed_path = exp_path / str(seed)
df_store_path = seed_path / 'train_store.csv'
@ -135,16 +149,20 @@ if __name__ == '__main__':
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
do_self_train = not train_to_task_first or last_accuracy >= min_task_acc
train_to_task_first = train_to_task_first if not do_self_train else False
for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
total=len(train_loader), desc='MetaNet Train - Batch'
):
# Self Train
self_train_loss = metanet.combined_self_train(n_st_per_batch,
reduction='mean', per_particle=False)
# noinspection PyUnboundLocalVariable
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
st_step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = st_step_log
if do_self_train:
self_train_loss = metanet.combined_self_train(n_st_per_batch, alpha=alpha_st_modulator,
reduction='mean', per_particle=False)
# noinspection PyUnboundLocalVariable
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
st_step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = st_step_log
# Task Train
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
@ -152,11 +170,12 @@ if __name__ == '__main__':
train_store.loc[train_store.shape[0]] = tsk_step_log
metric(y_pred.cpu(), batch_y.cpu())
last_accuracy = metric.compute().item()
if is_validation_epoch:
metanet = metanet.eval()
try:
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
Metric=f'Train {VAL_METRIC_NAME}', Score=last_accuracy)
train_store.loc[train_store.shape[0]] = validation_log
except RuntimeError:
pass
@ -208,7 +227,7 @@ if __name__ == '__main__':
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
try:
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
model_path = next(seed_path.glob(f'{FINAL_CHECKPOINT_NAME}'))
except StopIteration:
print('Model pattern did not trigger.')
print(f'Search path was: {seed_path}:')

View File

@ -19,7 +19,7 @@ from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grou
WORKER = 0
BATCHSIZE = 50
EPOCH = 30
EPOCH = 60
VALIDATION_FRQ = 3
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
# noinspection PyProtectedMember
@ -34,15 +34,19 @@ if __name__ == '__main__':
training = False
plotting = True
n_st = 100
robustness = True
attack = False
attack_ratio = 0.01
melt = False
melt_ratio = 0.01
n_st = 200
activation = None # nn.ReLU()
for weight_hidden_size in [2]:
for weight_hidden_size in [3]:
tsk_threshold = 0.85
weight_hidden_size = weight_hidden_size
residual_skip = True
n_seeds = 3
n_seeds = 10
depth = 3
width = 3
out = 1
@ -53,10 +57,13 @@ if __name__ == '__main__':
# noinspection PyUnresolvedReferences
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
res_str = f'{"" if residual_skip else "_no_res"}'
att_str = f'_att_{attack_ratio}' if attack else ''
mlt_str = f'_mlt_{melt_ratio}' if melt else ''
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
config_str = f'{res_str}'
exp_path = Path('output') / f'add_st_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
config_str = f'{res_str}{att_str}{ac_str}{mlt_str}{w_str}'
exp_path = Path('output') / f'add_st_{EPOCH}{config_str}'
# if not training:
# # noinspection PyRedeclaration
@ -113,6 +120,20 @@ if __name__ == '__main__':
st_step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = st_step_log
# Attack
if attack:
after_attack_loss = metanet.make_particles_attack(attack_ratio)
st_step_log = dict(Metric='After Attack Loss', Score=after_attack_loss.item())
st_step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = st_step_log
# Melt
if melt:
after_melt_loss = metanet.make_particles_melt(melt_ratio)
st_step_log = dict(Metric='After Melt Loss', Score=after_melt_loss.item())
st_step_log.update(dict(Epoch=epoch, Batch=batch))
train_store.loc[train_store.shape[0]] = st_step_log
# Task Train
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
@ -200,17 +221,20 @@ if __name__ == '__main__':
except ValueError as e:
print('ERROR:', e)
try:
tqdm.write('Trajectory plotting ...')
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
tqdm.write('Trajectory plotting Done')
except ValueError as e:
print('ERROR:', e)
try:
test_robustness(model_path, seeds=10)
pass
except ValueError as e:
print('ERROR:', e)
if robustness:
try:
test_robustness(model_path, seeds=10)
pass
except ValueError as e:
print('ERROR:', e)
if 2 <= n_seeds == sum(list(x.is_dir() for x in exp_path.iterdir())):
if plotting:

View File

@ -10,7 +10,6 @@ import torch.nn.functional as F
from torch import optim, Tensor
from tqdm import tqdm
def xavier_init(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
@ -22,17 +21,24 @@ def prng():
class FixTypes:
divergent = 'divergent'
fix_zero = 'fix_zero'
identity_func = 'identity_func'
fix_sec = 'fix_sec'
other_func = 'other_func'
divergent = 'Divergend'
fix_zero = 'All Zero'
identity_func = 'Self-Replicator'
fix_sec = 'Self-Replicator 2nd'
other_func = 'Other'
@classmethod
def all_types(cls):
return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
# Constants selecting the granularity of a network-wide particle operation.
# They are the accepted values of the `level` argument of
# MetaNet.make_particles_attack and MetaNet.make_particles_melt; in those
# methods only `cell` is handled, `all` and `layer` raise NotImplementedError.
class NetworkLevel:
all = 'All'
layer = 'Layer'
cell = 'Cell'
class Net(nn.Module):
@staticmethod
@ -365,6 +371,25 @@ class MetaCell(nn.Module):
def particles(self):
return (net for net in self.meta_weight_list)
def make_particles_attack(self, ratio=0.01):
    """Let randomly chosen particles of this cell attack a random partner.

    A shuffled copy of the cell's particles supplies one candidate partner
    per position. A particle is chosen when ``random.random() <= ratio``;
    it then calls ``attack`` on its partner unless the partner is the
    particle itself.

    Args:
        ratio: threshold compared against one ``random.random()`` draw per
            particle.
    """
    partners = list(self.particles)
    random.shuffle(partners)
    for particle, other in zip(self.particles, partners):
        # One draw per particle, taken before the partner check.
        if not (random.random() <= ratio):
            continue
        if other != particle:
            particle.attack(other)
def make_particles_melt(self, ratio=0.01):
    """Let randomly chosen particles of this cell melt with a random partner.

    A shuffled copy of the cell's particles supplies one candidate partner
    per position. A particle is chosen when ``random.random() <= ratio``;
    unless the partner is the particle itself, the result of
    ``particle.melt(partner)`` provides a target weight matrix that is
    applied back onto the particle.

    Args:
        ratio: threshold compared against one ``random.random()`` draw per
            particle.
    """
    partners = list(self.particles)
    random.shuffle(partners)
    for particle, other in zip(self.particles, partners):
        # One draw per particle, taken before the partner check.
        if not (random.random() <= ratio):
            continue
        if other != particle:
            melted = particle.melt(other)
            particle.apply_weights(melted.target_weight_matrix())
class MetaLayer(nn.Module):
def __init__(self, name, interface=4, width=4, # residual_skip=False,
@ -451,12 +476,12 @@ class MetaNet(nn.Module):
tensor = self._meta_layer_first(x)
residual = None
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
# if idx % 2 == 1 and self.residual_skip:
if self.residual_skip:
if idx % 2 == 1 and self.residual_skip:
# if self.residual_skip:
residual = tensor
tensor = meta_layer(tensor)
# if idx % 2 == 0 and self.residual_skip:
if self.residual_skip:
if idx % 2 == 0 and self.residual_skip:
# if self.residual_skip:
tensor = tensor + residual
tensor = self._meta_layer_last(tensor)
return tensor
@ -465,7 +490,7 @@ class MetaNet(nn.Module):
def particles(self):
return (cell for metalayer in self.all_layers for cell in metalayer.particles)
def combined_self_train(self, n_st_steps, reduction='mean', per_particle=True):
def combined_self_train(self, n_st_steps, reduction='mean', per_particle=True, alpha=1):
losses = []
@ -487,6 +512,8 @@ class MetaNet(nn.Module):
train_losses.append(loss)
train_losses = torch.hstack(train_losses).sum(dim=-1, keepdim=True)
if alpha not in [0, 1]:
train_losses *= alpha
train_losses.backward()
optim.step()
losses.append(train_losses.detach())
@ -505,6 +532,65 @@ class MetaNet(nn.Module):
weight.apply_weights(next(particle_weights_list).detach())
return self
def make_particles_attack(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
    """Trigger particle attacks and return the self-train loss measured afterwards.

    Args:
        ratio: forwarded to every cell's ``make_particles_attack``.
        level: one of the ``NetworkLevel`` constants. Only
            ``NetworkLevel.cell`` is implemented.
        reduction: forwarded to ``F.mse_loss`` for each particle's loss.

    Returns:
        The per-particle losses stacked with ``torch.hstack`` and summed
        over the last dimension (``keepdim=True``).

    Raises:
        NotImplementedError: if ``level`` is ``NetworkLevel.all`` or
            ``NetworkLevel.layer``.
        ValueError: if ``level`` is none of the ``NetworkLevel`` constants.
    """
    if level == NetworkLevel.all:
        raise NotImplementedError()
    elif level == NetworkLevel.layer:
        raise NotImplementedError()
    elif level == NetworkLevel.cell:
        for layer in self.all_layers:
            for cell in layer.meta_cell_list:
                cell.make_particles_attack(ratio)
    else:
        # Report the accepted options, not the rejected value.
        valid_levels = [NetworkLevel.all, NetworkLevel.layer, NetworkLevel.cell]
        raise ValueError(f'level has to be any of: {valid_levels}')

    # Self-train loss after the attack. Measurement only: no optimizer step
    # and no backward pass, hence no gradient tracking.
    with torch.no_grad():
        sa_losses = []
        for particle in self.particles:
            input_data = particle.input_weight_matrix()
            target_data = particle.create_target_weights(input_data)
            output = particle(input_data)
            loss = F.mse_loss(output, target_data, reduction=reduction)
            sa_losses.append(loss)
        after_attack_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
    return after_attack_loss
def make_particles_melt(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
    """Trigger particle melting and return the self-train loss measured afterwards.

    Args:
        ratio: forwarded to every cell's ``make_particles_melt``.
        level: one of the ``NetworkLevel`` constants. Only
            ``NetworkLevel.cell`` is implemented.
        reduction: forwarded to ``F.mse_loss`` for each particle's loss.

    Returns:
        The per-particle losses stacked with ``torch.hstack`` and summed
        over the last dimension (``keepdim=True``).

    Raises:
        NotImplementedError: if ``level`` is ``NetworkLevel.all`` or
            ``NetworkLevel.layer``.
        ValueError: if ``level`` is none of the ``NetworkLevel`` constants.
    """
    if level == NetworkLevel.all:
        raise NotImplementedError()
    elif level == NetworkLevel.layer:
        raise NotImplementedError()
    elif level == NetworkLevel.cell:
        for layer in self.all_layers:
            for cell in layer.meta_cell_list:
                cell.make_particles_melt(ratio)
    else:
        # Report the accepted options, not the rejected value.
        valid_levels = [NetworkLevel.all, NetworkLevel.layer, NetworkLevel.cell]
        raise ValueError(f'level has to be any of: {valid_levels}')

    # Self-train loss after the melt. Measurement only: no optimizer step
    # and no backward pass, hence no gradient tracking.
    with torch.no_grad():
        sa_losses = []
        for particle in self.particles:
            input_data = particle.input_weight_matrix()
            target_data = particle.create_target_weights(input_data)
            output = particle(input_data)
            loss = F.mse_loss(output, target_data, reduction=reduction)
            sa_losses.append(loss)
        after_melt_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
    return after_melt_loss
# Yields the first meta layer, then every layer of `_meta_layer_list`, then the
# last meta layer. A new generator is built on every access, so the result can
# only be iterated once per access.
@property
def all_layers(self):
return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
@ -541,12 +627,12 @@ class MetaNetCompareBaseline(nn.Module):
tensor = self._first_layer(x)
residual = None
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
# if idx % 2 == 1 and self.residual_skip:
if self.residual_skip:
if idx % 2 == 1 and self.residual_skip:
# if self.residual_skip:
residual = tensor
tensor = meta_layer(tensor)
# if idx % 2 == 0 and self.residual_skip:
if self.residual_skip:
if idx % 2 == 0 and self.residual_skip:
# if self.residual_skip:
tensor = tensor + residual
tensor = self._last_layer(tensor)
return tensor

View File

@ -69,35 +69,36 @@ def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_t
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
for name in layer.Weight.unique()])
plt.clf()
fig = plt.figure()
fig.set_figheight(10)
fig.set_figwidth(12)
ax = plt.axes(projection='3d')
plt.tight_layout()
if num_status_of_layer != 0:
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
for name in layer.Weight.unique()])
plt.clf()
fig = plt.figure()
fig.set_figheight(10)
fig.set_figwidth(12)
ax = plt.axes(projection='3d')
plt.tight_layout()
pca.fit(weight_batches)
w_transformed = pca.transform(weight_batches)
for transformed_trajectory, status in zip(
np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
if status == status_type:
xdata = transformed_trajectory[:, 0]
ydata = transformed_trajectory[:, 1]
zdata = all_epochs
ax.plot3D(xdata, ydata, zdata)
ax.scatter(xdata, ydata, zdata, s=7)
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}_grouped.png"
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
plt.clf()
plt.close(fig)
pca.fit(weight_batches)
w_transformed = pca.transform(weight_batches)
for transformed_trajectory, status in zip(
np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
if status == status_type:
xdata = transformed_trajectory[:, 0]
ydata = transformed_trajectory[:, 1]
zdata = all_epochs
ax.plot3D(xdata, ydata, zdata)
ax.scatter(xdata, ydata, zdata, s=7)
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}_grouped.png"
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
plt.clf()
plt.close(fig)
if __name__ == '__main__':