Late Evening Commit

This commit is contained in:
Steffen Illium 2022-03-06 22:24:00 +01:00
parent b3d4987cb8
commit ce5a36c8f4
8 changed files with 311 additions and 215 deletions

View File

@ -13,20 +13,33 @@ from torch.utils.data import Dataset
from tqdm import tqdm
from network import FixTypes as ft
from functionalities_test import test_for_fixpoints
from functionalities_test import test_for_fixpoints, FixTypes as ft
from sanity_check_weights import test_weights_as_model, extract_weights_from_model
WORKER = 10
BATCHSIZE = 500
EPOCH = 50
VALIDATION_FRQ = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_PATH = Path('data')
DATA_PATH.mkdir(exist_ok=True, parents=True)
PALETTE = sns.color_palette()
PALETTE.insert(0, PALETTE.pop(1)) # Orange First
FINAL_CHECKPOINT_NAME = f'trained_model_ckpt_FINAL.tp'
class AddGaussianNoise(object):
def __init__(self, ratio=1e-4):
self.ratio = ratio
def __call__(self, tensor: torch.Tensor):
return tensor + (torch.randn_like(tensor, device=tensor.device) * self.ratio)
def __repr__(self):
return self.__class__.__name__ + f'(ratio={self.ratio}'
class ToFloat:
@ -56,7 +69,10 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
if not final_model:
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
if isinstance(epoch_n, str):
ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
else:
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
@ -109,26 +125,27 @@ def plot_training_particle_types(path_to_dataframe):
plt.close('all')
plt.clf()
# load from Drive
df = pd.read_csv(path_to_dataframe, index_col=False)
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
# Set up figure
fig, ax = plt.subplots() # initializes figure and plots
data = df.loc[df['Metric'].isin(ft.all_types())]
fix_types = data['Metric'].unique()
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types], labels=fix_types.tolist())
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types],
labels=fix_types.tolist(), colors=PALETTE)
ax.set(ylabel='Particle Count', xlabel='Epoch')
ax.set_title('Particle Type Count')
# ax.set_title('Particle Type Count')
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
plt.tight_layout()
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
def plot_training_result(path_to_dataframe, metric_name='Accuracy', plot_name=None):
plt.clf()
# load from Drive
df = pd.read_csv(path_to_dataframe, index_col=False)
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
# Check if this is a single lineplot or if aggregated
group = ['Epoch', 'Metric']
@ -141,20 +158,22 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
# plots the first set of data
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
grouped_for_lineplot = data.groupby(group).mean()
palette_len_1 = len(grouped_for_lineplot.droplevel(0).reset_index().Metric.unique())
sns.lineplot(data=data.groupby(group).mean(), x='Epoch', y='Score', hue='Metric',
palette=palette, ax=ax1, ci='sd')
sns.lineplot(data=grouped_for_lineplot, x='Epoch', y='Score', hue='Metric',
palette=PALETTE[:palette_len_1], ax=ax1, ci='sd')
# plots the second set of data
data = df[(df['Metric'] == f'Test {metric}') | (df['Metric'] == f'Train {metric}')]
palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ci='sd')
data = df[(df['Metric'] == f'Test {metric_name}') | (df['Metric'] == f'Train {metric_name}')]
palette_len_2 = len(data.Metric.unique())
sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric',
palette=PALETTE[palette_len_1:palette_len_2+palette_len_1], ci='sd')
ax1.set(yscale='log', ylabel='Losses')
# ax1.set_title('Training Lineplot')
ax2.set(ylabel=metric)
if metric != 'MAE':
ax2.set(ylabel=metric_name)
if metric_name != 'Accuracy':
ax2.set(yscale='log')
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
@ -166,35 +185,39 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
def plot_network_connectivity_by_fixtype(path_to_trained_model):
m = torch.load(path_to_trained_model, map_location=torch.device('cpu')).eval()
m = torch.load(path_to_trained_model, map_location=DEVICE).eval()
# noinspection PyProtectedMember
particles = list(m.particles)
df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
df = pd.DataFrame(columns=['Type', 'Layer', 'Neuron', 'Name'])
for prtcl in particles:
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
for layer in list(df['layer'].unique()):
for layer in list(df['Layer'].unique()):
# Rescale
divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
df.loc[(df['layer'] == layer), 'neuron'] /= divisor
divisor = df.loc[(df['Layer'] == layer), 'Neuron'].max()
df.loc[(df['Layer'] == layer), 'Neuron'] /= divisor
tqdm.write(f'Connectivity Data gathered')
for n, fixtype in enumerate(ft.all_types()):
if df[df['type'] == fixtype].shape[0] > 0:
df = df.sort_values('Type')
n = 0
for fixtype in ft.all_types():
if df[df['Type'] == fixtype].shape[0] > 0:
plt.clf()
ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
ax = sns.lineplot(y='Neuron', x='Layer', hue='Name', data=df[df['Type'] == fixtype],
legend=False, estimator=None, lw=1)
_ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
ax.set_title(fixtype)
lines = ax.get_lines()
for line in lines:
line.set_color(sns.color_palette()[n])
line.set_color(PALETTE[n])
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["Type"] == fixtype].shape[0] // 2}')
n += 1
else:
tqdm.write(f'No Connectivity {fixtype}')
# tqdm.write(f'No Connectivity {fixtype}')
pass
# noinspection PyProtectedMember
@ -218,27 +241,33 @@ def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetric
tqdm.write(f'Zero_ident diff = {acc_diff}')
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
diff_df.to_csv(diff_store_path, mode='w', header=True, index=False)
return diff_store_path
# noinspection PyProtectedMember
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
metric_name = metric_class()._get_name()
diff_df = pd.read_csv(diff_store_path)
diff_df = pd.read_csv(diff_store_path).sort_values('Particle Type')
particle_dict = defaultdict(lambda: 0)
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
tqdm.write(str(dict(particle_dict)))
particle_dict = dict(particle_dict)
sorted_particle_dict = dict(sorted(particle_dict.items()))
tqdm.write(str(sorted_particle_dict))
plt.clf()
fig, ax = plt.subplots(ncols=2)
colors = sns.color_palette()[1:diff_df.shape[0]+1]
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors)
colors = PALETTE.copy()
colors.insert(0, colors.pop(-1))
palette_len = len(diff_df['Particle Type'].unique())
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
ax[0].set_title(f'{metric_name} after particle dropout')
ax[0].set_xlabel('Particle Type')
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)), )
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
colors=PALETTE[:len(sorted_particle_dict)])
ax[1].set_title('Particle Count')
plt.tight_layout()
@ -278,5 +307,83 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
return stp_log, y_prd
def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
activation_vector = torch.as_tensor([[0, 0, 0, 0, 1]], dtype=torch.float32, device=DEVICE)
binary_images = []
real_images = []
with torch.no_grad():
# noinspection PyProtectedMember
for cell in latest_model._meta_layer_first.meta_cell_list:
cell_image_binary = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
cell_image_real = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
for idx, particle in enumerate(cell.particles):
if particle.is_fixpoint == ft.identity_func:
cell_image_binary[idx] += 1
cell_image_real[idx] = particle(activation_vector).abs().squeeze().item()
binary_images.append(cell_image_binary.reshape((15, 15)))
real_images.append(cell_image_real.reshape((15, 15)))
binary_images = torch.stack(binary_images)
real_images = torch.stack(real_images)
binary_image = torch.sum(binary_images, keepdim=True, dim=0)
real_image = torch.sum(real_images, keepdim=True, dim=0)
mnist_images = [x for x, _ in dataloader]
mnist_mean = torch.cat(mnist_images).reshape(10000, 15, 15).abs().sum(dim=0)
fig, axs = plt.subplots(1, 3)
for idx, image in enumerate([binary_image, real_image, mnist_mean]):
img = axs[idx].imshow(image.squeeze().detach().cpu())
img.axes.axis('off')
plt.tight_layout()
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
plt.clf()
plt.close('all')
def plot_training_results_over_n_seeds(exp_path, df_train_store_name='train_store.csv', metric_name='Accuracy'):
combined_df_store_path = exp_path / f'comb_train_{exp_path.stem[:-1]}n.csv'
# noinspection PyUnboundLocalVariable
found_train_stores = exp_path.rglob(df_train_store_name)
train_dfs = []
for found_train_store in found_train_stores:
train_store_df = pd.read_csv(found_train_store, index_col=False)
train_store_df['Seed'] = int(found_train_store.parent.name)
train_dfs.append(train_store_df)
combined_train_df = pd.concat(train_dfs)
combined_train_df.to_csv(combined_df_store_path, index=False)
plot_training_result(combined_df_store_path, metric_name=metric_name,
plot_name=f"{combined_df_store_path.stem}.png"
)
plt.clf()
plt.close('all')
def sanity_weight_swap(exp_path, dataloader, metric_class=torchmetrics.Accuracy):
# noinspection PyProtectedMember
metric_name = metric_class()._get_name()
found_models = exp_path.rglob(f'*{FINAL_CHECKPOINT_NAME}')
df = pd.DataFrame(columns=['Seed', 'Model', metric_name])
for model_idx, found_model in enumerate(found_models):
model = torch.load(found_model, map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
for model_name, measurement in results.items():
df.loc[df.shape[0]] = (model_idx, model_name, measurement)
df.loc[df.shape[0]] = (model_idx, 'Difference', np.abs(np.subtract(*results.values())))
df.to_csv(exp_path / 'sanity_weight_swap.csv', index=False)
_ = sns.boxplot(data=df, x='Model', y=metric_name)
plt.tight_layout()
plt.savefig(exp_path / 'sanity_weight_swap.png', dpi=300)
plt.clf()
plt.close('all')
if __name__ == '__main__':
raise NotImplementedError('Test this here!!!')

View File

@ -1,17 +1,13 @@
import pickle
import pandas as pd
import torch
import random
import copy
from pathlib import Path
from tqdm import tqdm
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
from functionalities_test import (is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent,
FixTypes as FT)
from network import Net
from torch.nn import functional as F
from visualization import plot_loss, bar_chart_fixpoints
import seaborn as sns
from matplotlib import pyplot as plt
@ -26,7 +22,6 @@ def generate_perfekt_synthetic_fixpoint_weights():
[1.0], [0.0]
], dtype=torch.float32)
PALETTE = 10 * (
"#377eb8",
"#4daf4a",
@ -44,14 +39,16 @@ PALETTE = 10 * (
)
def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_step_size=10):
def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
model = torch.load(model_path, map_location='cpu')
networks = [x for x in model.particles if x.is_fixpoint == FT.identity_func]
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
row_headers = []
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
'Time to convergence', 'Time as fixpoint'])
with tqdm(total=max(len(networks), seeds)) as pbar:
with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
for setting, fixpoint in enumerate(networks): # 1 / n
row_headers.append(fixpoint.name)
for seed in range(seeds): # n / 1
@ -84,7 +81,7 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
steps, absolute_loss,
time_to_vergence[setting][noise_level],
time_as_fixpoint[setting][noise_level]]
pbar.update(1)
pbar.update(1)
# Get the measuremts at the highest time_time_to_vergence
df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
@ -92,6 +89,9 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
value_vars=['Time to convergence', 'Time as fixpoint'],
var_name="Measurement",
value_name="Steps").sort_values('Noise Level')
df_melted.to_csv(model_path.parent / 'robustness_boxplot.csv', index=False)
# Plotting
# plt.rcParams.update({
# "text.usetex": True,
@ -108,8 +108,8 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
# bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
# col='noise_level', col_wrap=3, showfliers=False)
filename = f"absolute_loss_perapplication_boxplot_grid_wild.png"
filepath = exp_path / filename
filename = f"robustness_boxplot.png"
filepath = model_path.parent / filename
plt.savefig(str(filepath))
plt.close('all')
return time_as_fixpoint, time_to_vergence

View File

@ -61,13 +61,13 @@ def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
if is_divergent(net):
fixpoint_counter[FixTypes.divergent] += 1
net.is_fixpoint = FixTypes.divergent
elif is_zero_fixpoint(net):
fixpoint_counter[FixTypes.fix_zero] += 1
net.is_fixpoint = FixTypes.fix_zero
elif is_identity_function(net): # is default value
fixpoint_counter[FixTypes.identity_func] += 1
net.is_fixpoint = FixTypes.identity_func
id_functions.append(net)
elif is_zero_fixpoint(net):
fixpoint_counter[FixTypes.fix_zero] += 1
net.is_fixpoint = FixTypes.fix_zero
elif is_secondary_fixpoint(net):
fixpoint_counter[FixTypes.fix_sec] += 1
net.is_fixpoint = FixTypes.fix_sec

View File

@ -1,10 +1,9 @@
# # # Imports
from collections import defaultdict
from pathlib import Path
import platform
import pandas as pd
import torchmetrics
import numpy as np
import torch
@ -19,7 +18,12 @@ from tqdm import tqdm
# noinspection DuplicatedCode
from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store,
plot_training_result, plot_training_particle_types,
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot)
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot,
highlight_fixpoints_vs_mnist_mean, AddGaussianNoise,
plot_training_results_over_n_seeds, sanity_weight_swap,
FINAL_CHECKPOINT_NAME)
from experiments.robustness_tester import test_robustness
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
if platform.node() == 'CarbonX':
debug = True
@ -29,33 +33,40 @@ if platform.node() == 'CarbonX':
else:
debug = False
from network import MetaNet
from network import MetaNet, FixTypes
from functionalities_test import test_for_fixpoints
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0), AddGaussianNoise()])
WORKER = 10 if not debug else 2
debug = False
BATCHSIZE = 2000 if not debug else 50
EPOCH = 50
EPOCH = 200
VALIDATION_FRQ = 3 if not debug else 1
VALIDATION_METRIC = torchmetrics.Accuracy
VAL_METRIC_CLASS = torchmetrics.Accuracy
# noinspection PyProtectedMember
VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_PATH = Path('data')
DATA_PATH.mkdir(exist_ok=True, parents=True)
if debug:
torch.autograd.set_detect_anomaly(True)
try:
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
except RuntimeError:
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
drop_last=True, num_workers=WORKER)
if __name__ == '__main__':
training = True
n_st = 150 # per batch !!
training = False
plotting = False
robustnes = True # EXPENSIV!!!!!!!
n_st = 300 # per batch !!
activation = None # nn.ReLU()
for weight_hidden_size in [4, 5, 6]:
for weight_hidden_size in [3]:
weight_hidden_size = weight_hidden_size
residual_skip = True
@ -73,11 +84,7 @@ if __name__ == '__main__':
st_str = f'_nst_{n_st}'
config_str = f'{res_str}{ac_str}{st_str}'
exp_path = Path('output') / f'add_st_{EPOCH}_{weight_hidden_size}{config_str}'
if not training:
# noinspection PyRedeclaration
exp_path = Path('output') / 'add_st_50_5'
exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}_gauss'
for seed in range(n_seeds):
seed_path = exp_path / str(seed)
@ -90,8 +97,6 @@ if __name__ == '__main__':
# Check if files do exist on project location, warn and break.
for path in [df_store_path, weight_store_path]:
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms)
except RuntimeError:
@ -122,8 +127,13 @@ if __name__ == '__main__':
metanet = metanet.train()
# Init metrics, even we do not need:
metric = VALIDATION_METRIC()
n_st_per_batch = n_st // len(train_loader)
metric = VAL_METRIC_CLASS()
n_st_per_batch = max(n_st // len(train_loader), 1)
if is_validation_epoch:
for particle in metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
total=len(train_loader), desc='MetaNet Train - Batch'
@ -166,17 +176,14 @@ if __name__ == '__main__':
train_store.loc[train_store.shape[0]] = val_step_log
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
# FLUSH to disk
if is_validation_epoch:
for particle in metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.to_csv(df_store_path, mode='a',
header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a',
header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
# FLUSH to disk
if is_validation_epoch:
train_store.to_csv(df_store_path, mode='a',
header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a',
header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
###########################################################
# EPOCHS endet
@ -200,38 +207,48 @@ if __name__ == '__main__':
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
plot_training_result(df_store_path)
plot_training_particle_types(df_store_path)
try:
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
except StopIteration:
print('Model pattern did not trigger.')
print(f'Search path was: {seed_path}:')
print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
exit(1)
try:
# noinspection PyUnboundLocalVariable
run_particle_dropout_and_plot(model_path, valid_loader=valid_loader, metric_class=VALIDATION_METRIC)
except (ValueError, NameError) as e:
print(e)
try:
plot_network_connectivity_by_fixtype(model_path)
except (ValueError, NameError) as e:
print(e)
if plotting:
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
if n_seeds >= 2:
combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
# noinspection PyUnboundLocalVariable
found_train_stores = exp_path.rglob(df_store_path.name)
train_dfs = []
for found_train_store in found_train_stores:
train_store_df = pd.read_csv(found_train_store, index_col=False)
train_store_df['Seed'] = int(found_train_store.parent.name)
train_dfs.append(train_store_df)
combined_train_df = pd.concat(train_dfs)
combined_train_df.to_csv(combined_df_store_path, index=False)
plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
plot_name=f"{combined_df_store_path.stem}.png"
)
plot_training_result(df_store_path)
plot_training_particle_types(df_store_path)
try:
# noinspection PyUnboundLocalVariable
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
except (ValueError, NameError) as e:
print(e)
try:
plot_network_connectivity_by_fixtype(model_path)
except (ValueError, NameError) as e:
print(e)
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
try:
for fixtype in FixTypes.all_types():
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
except ValueError as e:
print('ERROR:', e)
if robustnes:
try:
test_robustness(model_path, seeds=1)
except ValueError as e:
print('ERROR:', e)
if 2 <= n_seeds <= sum(list(x.is_dir() for x in exp_path.iterdir())):
if plotting:
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)

View File

@ -2,7 +2,6 @@ from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torchmetrics
from torch import nn
@ -15,32 +14,35 @@ from network import MetaNet
from functionalities_test import test_for_fixpoints, FixTypes as ft
from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype, \
checkpoint_and_validate
checkpoint_and_validate, plot_training_results_over_n_seeds, sanity_weight_swap, FINAL_CHECKPOINT_NAME
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
WORKER = 0
BATCHSIZE = 50
EPOCH = 30
VALIDATION_FRQ = 3
VALIDATION_METRIC = torchmetrics.MeanAbsoluteError
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
# noinspection PyProtectedMember
VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
plot_loader = DataLoader(AddTaskDataset(), batch_size=BATCHSIZE, shuffle=True,
drop_last=True, num_workers=WORKER)
if __name__ == '__main__':
training = False
plotting = False
n_st = 700
plotting = True
n_st = 100
activation = None # nn.ReLU()
for weight_hidden_size in [3, 4, 5]:
for weight_hidden_size in [2]:
tsk_threshold = 0.85
weight_hidden_size = weight_hidden_size
residual_skip = True
n_seeds = 10
n_seeds = 3
depth = 3
width = 3
out = 1
@ -97,8 +99,8 @@ if __name__ == '__main__':
metanet = metanet.train()
# Init metrics, even we do not need:
metric = VALIDATION_METRIC()
n_st_per_batch = n_st // len(train_load)
metric = VAL_METRIC_CLASS()
n_st_per_batch = max(1, (n_st // len(train_load)))
for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
total=len(train_load), desc='MetaNet Train - Batch'
@ -125,7 +127,7 @@ if __name__ == '__main__':
train_store.loc[train_store.shape[0]] = validation_log
mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
validation_metric=VALIDATION_METRIC).item()
validation_metric=VAL_METRIC_CLASS).item()
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
train_store.loc[train_store.shape[0]] = validation_log
@ -163,7 +165,7 @@ if __name__ == '__main__':
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
accuracy = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
validation_metric=VALIDATION_METRIC)
validation_metric=VAL_METRIC_CLASS)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
for particle in metanet.particles:
@ -175,11 +177,11 @@ if __name__ == '__main__':
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
if plotting:
plot_training_result(df_store_path, metric=VAL_METRIC_NAME)
plot_training_result(df_store_path, metric_name=VAL_METRIC_NAME)
plot_training_particle_types(df_store_path)
try:
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
except StopIteration:
print('####################################################')
print('ERROR: Model pattern did not trigger.')
@ -190,7 +192,7 @@ if __name__ == '__main__':
try:
# noinspection PyUnboundLocalVariable
run_particle_dropout_and_plot(model_path, valid_loader=vali_load, metric_class=VALIDATION_METRIC)
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
except ValueError as e:
print('ERROR:', e)
try:
@ -204,24 +206,15 @@ if __name__ == '__main__':
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
except ValueError as e:
print('ERROR:', e)
try:
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
model = torch.load(model_path, map_location='cpu')
test_robustness(list(model.particles), seed_path)
except ValueError as e:
print('ERROR:', e)
try:
test_robustness(model_path, seeds=10)
pass
except ValueError as e:
print('ERROR:', e)
if n_seeds >= 2:
combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
# noinspection PyUnboundLocalVariable
found_train_stores = exp_path.rglob(df_store_path.name)
train_dfs = []
for found_train_store in found_train_stores:
train_store_df = pd.read_csv(found_train_store, index_col=False)
train_store_df['Seed'] = int(found_train_store.parent.name)
train_dfs.append(train_store_df)
combined_train_df = pd.concat(train_dfs)
combined_train_df.to_csv(combined_df_store_path, index=False)
plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
plot_name=f"{combined_df_store_path.stem}.png"
)
if 2 <= n_seeds == sum(list(x.is_dir() for x in exp_path.iterdir())):
if plotting:
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)

View File

@ -443,7 +443,8 @@ class MetaNet(nn.Module):
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
)
replaced_particles += 1
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
if replaced_particles != 0:
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
return self
def forward(self, x):
@ -538,17 +539,15 @@ class MetaNetCompareBaseline(nn.Module):
def forward(self, x):
tensor = self._first_layer(x)
if self.activation:
tensor = self.activation(tensor)
residual = None
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
tensor = meta_layer(tensor)
if idx % 2 == 1 and self.residual_skip:
# if idx % 2 == 1 and self.residual_skip:
if self.residual_skip:
residual = tensor
if idx % 2 == 0 and self.residual_skip:
tensor = meta_layer(tensor)
# if idx % 2 == 0 and self.residual_skip:
if self.residual_skip:
tensor = tensor + residual
if self.activation:
tensor = self.activation(tensor)
tensor = self._last_layer(tensor)
return tensor

View File

@ -24,34 +24,35 @@ def plot_single_3d_trajectories_by_layer(model_path, all_weights_path, status_ty
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
for name in layer.Weight.unique()]
plt.clf()
fig = plt.figure()
ax = plt.axes(projection='3d')
fig.set_figheight(10)
fig.set_figwidth(12)
plt.tight_layout()
if num_status_of_layer != 0:
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
for name in layer.Weight.unique()]
plt.clf()
fig = plt.figure()
ax = plt.axes(projection='3d')
fig.set_figheight(10)
fig.set_figwidth(12)
plt.tight_layout()
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
if status == status_type:
pca.fit(weights_of_net)
transformed_trajectory = pca.transform(weights_of_net)
xdata = transformed_trajectory[:, 0]
ydata = transformed_trajectory[:, 1]
zdata = all_epochs
ax.plot3D(xdata, ydata, zdata)
ax.scatter(xdata, ydata, zdata, s=7)
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
plt.clf()
plt.close(fig)
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
if status == status_type:
pca.fit(weights_of_net)
transformed_trajectory = pca.transform(weights_of_net)
xdata = transformed_trajectory[:, 0]
ydata = transformed_trajectory[:, 1]
zdata = all_epochs
ax.plot3D(xdata, ydata, zdata)
ax.scatter(xdata, ydata, zdata, s=7)
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
plt.clf()
plt.close(fig)
def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):

View File

@ -28,9 +28,10 @@ def extract_weights_from_model(model:MetaNet)->dict:
return dict(weights)
def test_weights_as_model(meta_net, new_weights:dict, data):
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth, width=meta_net.width, out=meta_net.out,
residual_skip=True)
def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
width=meta_net.width, out=meta_net.out,
residual_skip=meta_net.residual_skip)
with torch.no_grad():
new_weight_values = list(new_weights.values())
old_parameters = list(transfer_net.parameters())
@ -39,40 +40,18 @@ def test_weights_as_model(meta_net, new_weights:dict, data):
parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
transfer_net.eval()
# Test if the margin of error is similar
im_t = defaultdict(list)
rand = torch.randn((1, 15 * 15))
for net in [meta_net, transfer_net]:
tensor = rand.clone()
for layer in net.all_layers:
tensor = layer(tensor)
im_t[net.__class__.__name__].append(tensor.detach())
im_t = dict(im_t)
all_close = {f'layer_{idx}': torch.allclose(y1.detach(), y2.detach(), rtol=0, atol=e
) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
}
print(f'Cummulative differences per layer is smaller then {e}:\n {all_close}')
# all_errors = {f'layer_{idx}': torch.absolute(y1.detach(), y2.detach(), rtol=0, atol=e
# ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
# }
results = dict()
for net in [meta_net, transfer_net]:
net.eval()
metric = torchmetrics.Accuracy()
with tqdm(desc='Test Batch: ') as pbar:
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
y = net(batch_x)
acc = metric(y.cpu(), batch_y.cpu())
pbar.set_postfix_str(f'Acc: {acc}')
pbar.update()
metric = metric_class()
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
y = net(batch_x)
metric(y.cpu(), batch_y.cpu())
# metric on all batches using custom accumulation
acc = metric.compute()
tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")
# metric on all batches using custom accumulation
measure = metric.compute()
results[net.__class__.__name__] = measure.item()
return results
if __name__ == '__main__':