Late Evening Commit

This commit is contained in:
Steffen Illium
2022-03-06 22:24:00 +01:00
parent b3d4987cb8
commit ce5a36c8f4
8 changed files with 311 additions and 215 deletions

View File

@@ -13,20 +13,33 @@ from torch.utils.data import Dataset
from tqdm import tqdm
from network import FixTypes as ft
from functionalities_test import test_for_fixpoints
from functionalities_test import test_for_fixpoints, FixTypes as ft
from sanity_check_weights import test_weights_as_model, extract_weights_from_model
WORKER = 10
BATCHSIZE = 500
EPOCH = 50
VALIDATION_FRQ = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_PATH = Path('data')
DATA_PATH.mkdir(exist_ok=True, parents=True)
PALETTE = sns.color_palette()
PALETTE.insert(0, PALETTE.pop(1)) # Orange First
FINAL_CHECKPOINT_NAME = f'trained_model_ckpt_FINAL.tp'
class AddGaussianNoise(object):
def __init__(self, ratio=1e-4):
self.ratio = ratio
def __call__(self, tensor: torch.Tensor):
return tensor + (torch.randn_like(tensor, device=tensor.device) * self.ratio)
def __repr__(self):
return self.__class__.__name__ + f'(ratio={self.ratio}'
class ToFloat:
@@ -56,7 +69,10 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
if not final_model:
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
if isinstance(epoch_n, str):
ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
else:
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
@@ -109,26 +125,27 @@ def plot_training_particle_types(path_to_dataframe):
plt.close('all')
plt.clf()
# load from Drive
df = pd.read_csv(path_to_dataframe, index_col=False)
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
# Set up figure
fig, ax = plt.subplots() # initializes figure and plots
data = df.loc[df['Metric'].isin(ft.all_types())]
fix_types = data['Metric'].unique()
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types], labels=fix_types.tolist())
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types],
labels=fix_types.tolist(), colors=PALETTE)
ax.set(ylabel='Particle Count', xlabel='Epoch')
ax.set_title('Particle Type Count')
# ax.set_title('Particle Type Count')
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
plt.tight_layout()
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
def plot_training_result(path_to_dataframe, metric_name='Accuracy', plot_name=None):
plt.clf()
# load from Drive
df = pd.read_csv(path_to_dataframe, index_col=False)
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
# Check if this is a single lineplot or if aggregated
group = ['Epoch', 'Metric']
@@ -141,20 +158,22 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
# plots the first set of data
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
grouped_for_lineplot = data.groupby(group).mean()
palette_len_1 = len(grouped_for_lineplot.droplevel(0).reset_index().Metric.unique())
sns.lineplot(data=data.groupby(group).mean(), x='Epoch', y='Score', hue='Metric',
palette=palette, ax=ax1, ci='sd')
sns.lineplot(data=grouped_for_lineplot, x='Epoch', y='Score', hue='Metric',
palette=PALETTE[:palette_len_1], ax=ax1, ci='sd')
# plots the second set of data
data = df[(df['Metric'] == f'Test {metric}') | (df['Metric'] == f'Train {metric}')]
palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ci='sd')
data = df[(df['Metric'] == f'Test {metric_name}') | (df['Metric'] == f'Train {metric_name}')]
palette_len_2 = len(data.Metric.unique())
sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric',
palette=PALETTE[palette_len_1:palette_len_2+palette_len_1], ci='sd')
ax1.set(yscale='log', ylabel='Losses')
# ax1.set_title('Training Lineplot')
ax2.set(ylabel=metric)
if metric != 'MAE':
ax2.set(ylabel=metric_name)
if metric_name != 'Accuracy':
ax2.set(yscale='log')
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
@@ -166,35 +185,39 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
def plot_network_connectivity_by_fixtype(path_to_trained_model):
m = torch.load(path_to_trained_model, map_location=torch.device('cpu')).eval()
m = torch.load(path_to_trained_model, map_location=DEVICE).eval()
# noinspection PyProtectedMember
particles = list(m.particles)
df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
df = pd.DataFrame(columns=['Type', 'Layer', 'Neuron', 'Name'])
for prtcl in particles:
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
for layer in list(df['layer'].unique()):
for layer in list(df['Layer'].unique()):
# Rescale
divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
df.loc[(df['layer'] == layer), 'neuron'] /= divisor
divisor = df.loc[(df['Layer'] == layer), 'Neuron'].max()
df.loc[(df['Layer'] == layer), 'Neuron'] /= divisor
tqdm.write(f'Connectivity Data gathered')
for n, fixtype in enumerate(ft.all_types()):
if df[df['type'] == fixtype].shape[0] > 0:
df = df.sort_values('Type')
n = 0
for fixtype in ft.all_types():
if df[df['Type'] == fixtype].shape[0] > 0:
plt.clf()
ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
ax = sns.lineplot(y='Neuron', x='Layer', hue='Name', data=df[df['Type'] == fixtype],
legend=False, estimator=None, lw=1)
_ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
ax.set_title(fixtype)
lines = ax.get_lines()
for line in lines:
line.set_color(sns.color_palette()[n])
line.set_color(PALETTE[n])
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["Type"] == fixtype].shape[0] // 2}')
n += 1
else:
tqdm.write(f'No Connectivity {fixtype}')
# tqdm.write(f'No Connectivity {fixtype}')
pass
# noinspection PyProtectedMember
@@ -218,27 +241,33 @@ def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetric
tqdm.write(f'Zero_ident diff = {acc_diff}')
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
diff_df.to_csv(diff_store_path, mode='w', header=True, index=False)
return diff_store_path
# noinspection PyProtectedMember
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
metric_name = metric_class()._get_name()
diff_df = pd.read_csv(diff_store_path)
diff_df = pd.read_csv(diff_store_path).sort_values('Particle Type')
particle_dict = defaultdict(lambda: 0)
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
tqdm.write(str(dict(particle_dict)))
particle_dict = dict(particle_dict)
sorted_particle_dict = dict(sorted(particle_dict.items()))
tqdm.write(str(sorted_particle_dict))
plt.clf()
fig, ax = plt.subplots(ncols=2)
colors = sns.color_palette()[1:diff_df.shape[0]+1]
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors)
colors = PALETTE.copy()
colors.insert(0, colors.pop(-1))
palette_len = len(diff_df['Particle Type'].unique())
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
ax[0].set_title(f'{metric_name} after particle dropout')
ax[0].set_xlabel('Particle Type')
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)), )
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
colors=PALETTE[:len(sorted_particle_dict)])
ax[1].set_title('Particle Count')
plt.tight_layout()
@@ -278,5 +307,83 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
return stp_log, y_prd
def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
activation_vector = torch.as_tensor([[0, 0, 0, 0, 1]], dtype=torch.float32, device=DEVICE)
binary_images = []
real_images = []
with torch.no_grad():
# noinspection PyProtectedMember
for cell in latest_model._meta_layer_first.meta_cell_list:
cell_image_binary = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
cell_image_real = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
for idx, particle in enumerate(cell.particles):
if particle.is_fixpoint == ft.identity_func:
cell_image_binary[idx] += 1
cell_image_real[idx] = particle(activation_vector).abs().squeeze().item()
binary_images.append(cell_image_binary.reshape((15, 15)))
real_images.append(cell_image_real.reshape((15, 15)))
binary_images = torch.stack(binary_images)
real_images = torch.stack(real_images)
binary_image = torch.sum(binary_images, keepdim=True, dim=0)
real_image = torch.sum(real_images, keepdim=True, dim=0)
mnist_images = [x for x, _ in dataloader]
mnist_mean = torch.cat(mnist_images).reshape(10000, 15, 15).abs().sum(dim=0)
fig, axs = plt.subplots(1, 3)
for idx, image in enumerate([binary_image, real_image, mnist_mean]):
img = axs[idx].imshow(image.squeeze().detach().cpu())
img.axes.axis('off')
plt.tight_layout()
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
plt.clf()
plt.close('all')
def plot_training_results_over_n_seeds(exp_path, df_train_store_name='train_store.csv', metric_name='Accuracy'):
combined_df_store_path = exp_path / f'comb_train_{exp_path.stem[:-1]}n.csv'
# noinspection PyUnboundLocalVariable
found_train_stores = exp_path.rglob(df_train_store_name)
train_dfs = []
for found_train_store in found_train_stores:
train_store_df = pd.read_csv(found_train_store, index_col=False)
train_store_df['Seed'] = int(found_train_store.parent.name)
train_dfs.append(train_store_df)
combined_train_df = pd.concat(train_dfs)
combined_train_df.to_csv(combined_df_store_path, index=False)
plot_training_result(combined_df_store_path, metric_name=metric_name,
plot_name=f"{combined_df_store_path.stem}.png"
)
plt.clf()
plt.close('all')
def sanity_weight_swap(exp_path, dataloader, metric_class=torchmetrics.Accuracy):
# noinspection PyProtectedMember
metric_name = metric_class()._get_name()
found_models = exp_path.rglob(f'*{FINAL_CHECKPOINT_NAME}')
df = pd.DataFrame(columns=['Seed', 'Model', metric_name])
for model_idx, found_model in enumerate(found_models):
model = torch.load(found_model, map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
for model_name, measurement in results.items():
df.loc[df.shape[0]] = (model_idx, model_name, measurement)
df.loc[df.shape[0]] = (model_idx, 'Difference', np.abs(np.subtract(*results.values())))
df.to_csv(exp_path / 'sanity_weight_swap.csv', index=False)
_ = sns.boxplot(data=df, x='Model', y=metric_name)
plt.tight_layout()
plt.savefig(exp_path / 'sanity_weight_swap.png', dpi=300)
plt.clf()
plt.close('all')
if __name__ == '__main__':
raise NotImplementedError('Test this here!!!')

View File

@@ -1,17 +1,13 @@
import pickle
import pandas as pd
import torch
import random
import copy
from pathlib import Path
from tqdm import tqdm
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
from functionalities_test import (is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent,
FixTypes as FT)
from network import Net
from torch.nn import functional as F
from visualization import plot_loss, bar_chart_fixpoints
import seaborn as sns
from matplotlib import pyplot as plt
@@ -26,7 +22,6 @@ def generate_perfekt_synthetic_fixpoint_weights():
[1.0], [0.0]
], dtype=torch.float32)
PALETTE = 10 * (
"#377eb8",
"#4daf4a",
@@ -44,14 +39,16 @@ PALETTE = 10 * (
)
def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_step_size=10):
def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
model = torch.load(model_path, map_location='cpu')
networks = [x for x in model.particles if x.is_fixpoint == FT.identity_func]
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
row_headers = []
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
'Time to convergence', 'Time as fixpoint'])
with tqdm(total=max(len(networks), seeds)) as pbar:
with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
for setting, fixpoint in enumerate(networks): # 1 / n
row_headers.append(fixpoint.name)
for seed in range(seeds): # n / 1
@@ -84,7 +81,7 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
steps, absolute_loss,
time_to_vergence[setting][noise_level],
time_as_fixpoint[setting][noise_level]]
pbar.update(1)
pbar.update(1)
# Get the measuremts at the highest time_time_to_vergence
df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
@@ -92,6 +89,9 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
value_vars=['Time to convergence', 'Time as fixpoint'],
var_name="Measurement",
value_name="Steps").sort_values('Noise Level')
df_melted.to_csv(model_path.parent / 'robustness_boxplot.csv', index=False)
# Plotting
# plt.rcParams.update({
# "text.usetex": True,
@@ -108,8 +108,8 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
# bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
# col='noise_level', col_wrap=3, showfliers=False)
filename = f"absolute_loss_perapplication_boxplot_grid_wild.png"
filepath = exp_path / filename
filename = f"robustness_boxplot.png"
filepath = model_path.parent / filename
plt.savefig(str(filepath))
plt.close('all')
return time_as_fixpoint, time_to_vergence