Late Evening Commit
This commit is contained in:
parent b3d4987cb8
commit ce5a36c8f4
@@ -13,20 +13,33 @@ from torch.utils.data import Dataset
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
from network import FixTypes as ft
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from sanity_check_weights import test_weights_as_model, extract_weights_from_model
|
||||
|
||||
WORKER = 10
|
||||
BATCHSIZE = 500
|
||||
EPOCH = 50
|
||||
VALIDATION_FRQ = 3
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
PALETTE = sns.color_palette()
|
||||
PALETTE.insert(0, PALETTE.pop(1)) # Orange First
|
||||
|
||||
FINAL_CHECKPOINT_NAME = f'trained_model_ckpt_FINAL.tp'
|
||||
|
||||
|
||||
class AddGaussianNoise(object):
|
||||
def __init__(self, ratio=1e-4):
|
||||
self.ratio = ratio
|
||||
|
||||
def __call__(self, tensor: torch.Tensor):
|
||||
return tensor + (torch.randn_like(tensor, device=tensor.device) * self.ratio)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(ratio={self.ratio})'
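For context, AddGaussianNoise is meant to be chained into the MNIST preprocessing used by the experiment script; a minimal sketch of that composition (assuming torchvision transforms and torch.nn.Flatten, mirroring the utility_transforms line added in meta_task_exp.py):
from torchvision.transforms import Compose, ToTensor, Resize
from torch.nn import Flatten
# Sketch only: noise is applied last, after the image is flattened to a 225-vector.
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)),
                              Flatten(start_dim=0), AddGaussianNoise()])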
|
||||
|
||||
|
||||
class ToFloat:
|
||||
|
||||
@@ -56,7 +69,10 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
if not final_model:
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
|
||||
if isinstance(epoch_n, str):
|
||||
ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
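Checkpoint naming after this change, illustrated (assuming epoch_n arrives as a string where zfill is used; paths shown relative to out_path):
# set_checkpoint(model, 'out', '42')                      -> out/ckpt/0042_model_ckpt.tp
# set_checkpoint(model, 'out', 'best', final_model=True)  -> out/best_trained_model_ckpt_FINAL.tp
# set_checkpoint(model, 'out', 50, final_model=True)      -> out/trained_model_ckpt_FINAL.tp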
|
||||
@@ -109,26 +125,27 @@ def plot_training_particle_types(path_to_dataframe):
|
||||
plt.close('all')
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False)
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
# Set up figure
|
||||
fig, ax = plt.subplots() # initializes figure and plots
|
||||
data = df.loc[df['Metric'].isin(ft.all_types())]
|
||||
fix_types = data['Metric'].unique()
|
||||
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
|
||||
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types], labels=fix_types.tolist())
|
||||
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types],
|
||||
labels=fix_types.tolist(), colors=PALETTE)
|
||||
|
||||
ax.set(ylabel='Particle Count', xlabel='Epoch')
|
||||
ax.set_title('Particle Type Count')
|
||||
# ax.set_title('Particle Type Count')
|
||||
|
||||
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
|
||||
|
||||
|
||||
def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
|
||||
def plot_training_result(path_to_dataframe, metric_name='Accuracy', plot_name=None):
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False)
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
|
||||
# Check if this is a single lineplot or if aggregated
|
||||
group = ['Epoch', 'Metric']
|
||||
@@ -141,20 +158,22 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
|
||||
|
||||
# plots the first set of data
|
||||
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
|
||||
palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
|
||||
grouped_for_lineplot = data.groupby(group).mean()
|
||||
palette_len_1 = len(grouped_for_lineplot.droplevel(0).reset_index().Metric.unique())
|
||||
|
||||
sns.lineplot(data=data.groupby(group).mean(), x='Epoch', y='Score', hue='Metric',
|
||||
palette=palette, ax=ax1, ci='sd')
|
||||
sns.lineplot(data=grouped_for_lineplot, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[:palette_len_1], ax=ax1, ci='sd')
|
||||
|
||||
# plots the second set of data
|
||||
data = df[(df['Metric'] == f'Test {metric}') | (df['Metric'] == f'Train {metric}')]
|
||||
palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ci='sd')
|
||||
data = df[(df['Metric'] == f'Test {metric_name}') | (df['Metric'] == f'Train {metric_name}')]
|
||||
palette_len_2 = len(data.Metric.unique())
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[palette_len_1:palette_len_2+palette_len_1], ci='sd')
|
||||
|
||||
ax1.set(yscale='log', ylabel='Losses')
|
||||
# ax1.set_title('Training Lineplot')
|
||||
ax2.set(ylabel=metric)
|
||||
if metric != 'MAE':
|
||||
ax2.set(ylabel=metric_name)
|
||||
if metric_name != 'Accuracy':
|
||||
ax2.set(yscale='log')
|
||||
|
||||
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
|
||||
@@ -166,35 +185,39 @@ def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
|
||||
|
||||
|
||||
def plot_network_connectivity_by_fixtype(path_to_trained_model):
|
||||
m = torch.load(path_to_trained_model, map_location=torch.device('cpu')).eval()
|
||||
m = torch.load(path_to_trained_model, map_location=DEVICE).eval()
|
||||
# noinspection PyProtectedMember
|
||||
particles = list(m.particles)
|
||||
df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
|
||||
df = pd.DataFrame(columns=['Type', 'Layer', 'Neuron', 'Name'])
|
||||
|
||||
for prtcl in particles:
|
||||
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
|
||||
for layer in list(df['layer'].unique()):
|
||||
for layer in list(df['Layer'].unique()):
|
||||
# Rescale
|
||||
divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
|
||||
df.loc[(df['layer'] == layer), 'neuron'] /= divisor
|
||||
divisor = df.loc[(df['Layer'] == layer), 'Neuron'].max()
|
||||
df.loc[(df['Layer'] == layer), 'Neuron'] /= divisor
|
||||
|
||||
tqdm.write(f'Connectivity Data gathered')
|
||||
for n, fixtype in enumerate(ft.all_types()):
|
||||
if df[df['type'] == fixtype].shape[0] > 0:
|
||||
df = df.sort_values('Type')
|
||||
n = 0
|
||||
for fixtype in ft.all_types():
|
||||
if df[df['Type'] == fixtype].shape[0] > 0:
|
||||
plt.clf()
|
||||
ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
|
||||
ax = sns.lineplot(y='Neuron', x='Layer', hue='Name', data=df[df['Type'] == fixtype],
|
||||
legend=False, estimator=None, lw=1)
|
||||
_ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
|
||||
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
|
||||
ax.set_title(fixtype)
|
||||
lines = ax.get_lines()
|
||||
for line in lines:
|
||||
line.set_color(sns.color_palette()[n])
|
||||
line.set_color(PALETTE[n])
|
||||
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
|
||||
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
|
||||
tqdm.write(f'Connectivity plotted: {fixtype} - n = {df[df["Type"] == fixtype].shape[0] // 2}')
|
||||
n += 1
|
||||
else:
|
||||
tqdm.write(f'No Connectivity {fixtype}')
|
||||
# tqdm.write(f'No Connectivity {fixtype}')
|
||||
pass
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
@@ -218,27 +241,33 @@ def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetric
|
||||
tqdm.write(f'Zero_ident diff = {acc_diff}')
|
||||
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
|
||||
|
||||
diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
|
||||
diff_df.to_csv(diff_store_path, mode='w', header=True, index=False)
|
||||
return diff_store_path
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
|
||||
metric_name = metric_class()._get_name()
|
||||
diff_df = pd.read_csv(diff_store_path)
|
||||
diff_df = pd.read_csv(diff_store_path).sort_values('Particle Type')
|
||||
particle_dict = defaultdict(lambda: 0)
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
|
||||
tqdm.write(str(dict(particle_dict)))
|
||||
particle_dict = dict(particle_dict)
|
||||
sorted_particle_dict = dict(sorted(particle_dict.items()))
|
||||
tqdm.write(str(sorted_particle_dict))
|
||||
plt.clf()
|
||||
fig, ax = plt.subplots(ncols=2)
|
||||
colors = sns.color_palette()[1:diff_df.shape[0]+1]
|
||||
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors)
|
||||
colors = PALETTE.copy()
|
||||
colors.insert(0, colors.pop(-1))
|
||||
palette_len = len(diff_df['Particle Type'].unique())
|
||||
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
|
||||
|
||||
ax[0].set_title(f'{metric_name} after particle dropout')
|
||||
ax[0].set_xlabel('Particle Type')
|
||||
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
|
||||
|
||||
ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)), )
|
||||
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
|
||||
colors=PALETTE[:len(sorted_particle_dict)])
|
||||
ax[1].set_title('Particle Count')
|
||||
|
||||
plt.tight_layout()
|
||||
@@ -278,5 +307,83 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
|
||||
return stp_log, y_prd
|
||||
|
||||
|
||||
def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
activation_vector = torch.as_tensor([[0, 0, 0, 0, 1]], dtype=torch.float32, device=DEVICE)
|
||||
binary_images = []
|
||||
real_images = []
|
||||
with torch.no_grad():
|
||||
# noinspection PyProtectedMember
|
||||
for cell in latest_model._meta_layer_first.meta_cell_list:
|
||||
cell_image_binary = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
cell_image_real = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
for idx, particle in enumerate(cell.particles):
|
||||
if particle.is_fixpoint == ft.identity_func:
|
||||
cell_image_binary[idx] += 1
|
||||
cell_image_real[idx] = particle(activation_vector).abs().squeeze().item()
|
||||
binary_images.append(cell_image_binary.reshape((15, 15)))
|
||||
real_images.append(cell_image_real.reshape((15, 15)))
|
||||
|
||||
binary_images = torch.stack(binary_images)
|
||||
real_images = torch.stack(real_images)
|
||||
|
||||
binary_image = torch.sum(binary_images, keepdim=True, dim=0)
|
||||
real_image = torch.sum(real_images, keepdim=True, dim=0)
|
||||
|
||||
mnist_images = [x for x, _ in dataloader]
|
||||
mnist_mean = torch.cat(mnist_images).reshape(10000, 15, 15).abs().sum(dim=0)
|
||||
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
|
||||
for idx, image in enumerate([binary_image, real_image, mnist_mean]):
|
||||
img = axs[idx].imshow(image.squeeze().detach().cpu())
|
||||
img.axes.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def plot_training_results_over_n_seeds(exp_path, df_train_store_name='train_store.csv', metric_name='Accuracy'):
|
||||
combined_df_store_path = exp_path / f'comb_train_{exp_path.stem[:-1]}n.csv'
|
||||
# noinspection PyUnboundLocalVariable
|
||||
found_train_stores = exp_path.rglob(df_train_store_name)
|
||||
train_dfs = []
|
||||
for found_train_store in found_train_stores:
|
||||
train_store_df = pd.read_csv(found_train_store, index_col=False)
|
||||
train_store_df['Seed'] = int(found_train_store.parent.name)
|
||||
train_dfs.append(train_store_df)
|
||||
combined_train_df = pd.concat(train_dfs)
|
||||
combined_train_df.to_csv(combined_df_store_path, index=False)
|
||||
plot_training_result(combined_df_store_path, metric_name=metric_name,
|
||||
plot_name=f"{combined_df_store_path.stem}.png"
|
||||
)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def sanity_weight_swap(exp_path, dataloader, metric_class=torchmetrics.Accuracy):
|
||||
# noinspection PyProtectedMember
|
||||
metric_name = metric_class()._get_name()
|
||||
found_models = exp_path.rglob(f'*{FINAL_CHECKPOINT_NAME}')
|
||||
df = pd.DataFrame(columns=['Seed', 'Model', metric_name])
|
||||
for model_idx, found_model in enumerate(found_models):
|
||||
model = torch.load(found_model, map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
|
||||
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
|
||||
for model_name, measurement in results.items():
|
||||
df.loc[df.shape[0]] = (model_idx, model_name, measurement)
|
||||
df.loc[df.shape[0]] = (model_idx, 'Difference', np.abs(np.subtract(*results.values())))
|
||||
|
||||
df.to_csv(exp_path / 'sanity_weight_swap.csv', index=False)
|
||||
_ = sns.boxplot(data=df, x='Model', y=metric_name)
|
||||
plt.tight_layout()
|
||||
plt.savefig(exp_path / 'sanity_weight_swap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
||||
|
@@ -1,17 +1,13 @@
|
||||
import pickle
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import random
|
||||
import copy
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
|
||||
from functionalities_test import (is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent,
|
||||
FixTypes as FT)
|
||||
from network import Net
|
||||
from torch.nn import functional as F
|
||||
from visualization import plot_loss, bar_chart_fixpoints
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
@@ -26,7 +22,6 @@ def generate_perfekt_synthetic_fixpoint_weights():
|
||||
[1.0], [0.0]
|
||||
], dtype=torch.float32)
|
||||
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
@@ -44,14 +39,16 @@ PALETTE = 10 * (
|
||||
)
|
||||
|
||||
|
||||
def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
model = torch.load(model_path, map_location='cpu')
|
||||
networks = [x for x in model.particles if x.is_fixpoint == FT.identity_func]
|
||||
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
row_headers = []
|
||||
|
||||
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
|
||||
'Time to convergence', 'Time as fixpoint'])
|
||||
with tqdm(total=max(len(networks), seeds)) as pbar:
|
||||
with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
|
||||
for setting, fixpoint in enumerate(networks): # 1 / n
|
||||
row_headers.append(fixpoint.name)
|
||||
for seed in range(seeds): # n / 1
|
||||
@@ -84,7 +81,7 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
|
||||
steps, absolute_loss,
|
||||
time_to_vergence[setting][noise_level],
|
||||
time_as_fixpoint[setting][noise_level]]
|
||||
pbar.update(1)
|
||||
pbar.update(1)
|
||||
|
||||
# Get the measurements at the highest time_to_vergence
|
||||
df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
|
||||
@@ -92,6 +89,9 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
|
||||
value_vars=['Time to convergence', 'Time as fixpoint'],
|
||||
var_name="Measurement",
|
||||
value_name="Steps").sort_values('Noise Level')
|
||||
|
||||
df_melted.to_csv(model_path.parent / 'robustness_boxplot.csv', index=False)
|
||||
|
||||
# Plotting
|
||||
# plt.rcParams.update({
|
||||
# "text.usetex": True,
|
||||
@@ -108,8 +108,8 @@ def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_ste
|
||||
# bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
|
||||
# col='noise_level', col_wrap=3, showfliers=False)
|
||||
|
||||
filename = f"absolute_loss_perapplication_boxplot_grid_wild.png"
|
||||
filepath = exp_path / filename
|
||||
filename = f"robustness_boxplot.png"
|
||||
filepath = model_path.parent / filename
|
||||
plt.savefig(str(filepath))
|
||||
plt.close('all')
|
||||
return time_as_fixpoint, time_to_vergence
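A short usage sketch of the new signature, matching the call sites in the experiment scripts (the seeds value is illustrative):
# The checkpoint path is passed directly; identity-fixpoint particles are
# selected from the loaded model inside the function.
time_as_fixpoint, time_to_vergence = test_robustness(model_path, seeds=1)
# Outputs land next to the checkpoint: robustness_boxplot.csv and robustness_boxplot.png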
|
||||
|
@@ -61,13 +61,13 @@ def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
|
||||
if is_divergent(net):
|
||||
fixpoint_counter[FixTypes.divergent] += 1
|
||||
net.is_fixpoint = FixTypes.divergent
|
||||
elif is_zero_fixpoint(net):
|
||||
fixpoint_counter[FixTypes.fix_zero] += 1
|
||||
net.is_fixpoint = FixTypes.fix_zero
|
||||
elif is_identity_function(net): # is default value
|
||||
fixpoint_counter[FixTypes.identity_func] += 1
|
||||
net.is_fixpoint = FixTypes.identity_func
|
||||
id_functions.append(net)
|
||||
elif is_zero_fixpoint(net):
|
||||
fixpoint_counter[FixTypes.fix_zero] += 1
|
||||
net.is_fixpoint = FixTypes.fix_zero
|
||||
elif is_secondary_fixpoint(net):
|
||||
fixpoint_counter[FixTypes.fix_sec] += 1
|
||||
net.is_fixpoint = FixTypes.fix_sec
|
||||
|
meta_task_exp.py
@@ -1,10 +1,9 @@
|
||||
|
||||
# # # Imports
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import platform
|
||||
|
||||
import pandas as pd
|
||||
import torchmetrics
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -19,7 +18,12 @@ from tqdm import tqdm
|
||||
# noinspection DuplicatedCode
|
||||
from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store,
|
||||
plot_training_result, plot_training_particle_types,
|
||||
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot)
|
||||
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot,
|
||||
highlight_fixpoints_vs_mnist_mean, AddGaussianNoise,
|
||||
plot_training_results_over_n_seeds, sanity_weight_swap,
|
||||
FINAL_CHECKPOINT_NAME)
|
||||
from experiments.robustness_tester import test_robustness
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
@@ -29,33 +33,40 @@ if platform.node() == 'CarbonX':
|
||||
else:
|
||||
debug = False
|
||||
|
||||
from network import MetaNet
|
||||
from network import MetaNet, FixTypes
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0), AddGaussianNoise()])
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 2000 if not debug else 50
|
||||
EPOCH = 50
|
||||
EPOCH = 200
|
||||
VALIDATION_FRQ = 3 if not debug else 1
|
||||
VALIDATION_METRIC = torchmetrics.Accuracy
|
||||
VAL_METRIC_CLASS = torchmetrics.Accuracy
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if debug:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
try:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
|
||||
plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = True
|
||||
n_st = 150 # per batch !!
|
||||
training = False
|
||||
plotting = False
|
||||
robustnes = True  # EXPENSIVE!!!
|
||||
n_st = 300 # per batch !!
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
for weight_hidden_size in [4, 5, 6]:
|
||||
for weight_hidden_size in [3]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
residual_skip = True
|
||||
@@ -73,11 +84,7 @@ if __name__ == '__main__':
|
||||
st_str = f'_nst_{n_st}'
|
||||
|
||||
config_str = f'{res_str}{ac_str}{st_str}'
|
||||
exp_path = Path('output') / f'add_st_{EPOCH}_{weight_hidden_size}{config_str}'
|
||||
|
||||
if not training:
|
||||
# noinspection PyRedeclaration
|
||||
exp_path = Path('output') / 'add_st_50_5'
|
||||
exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}_gauss'
|
||||
|
||||
for seed in range(n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
@@ -90,8 +97,6 @@ if __name__ == '__main__':
|
||||
# Check if files do exist on project location, warn and break.
|
||||
for path in [df_store_path, weight_store_path]:
|
||||
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
|
||||
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
|
||||
try:
|
||||
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms)
|
||||
except RuntimeError:
|
||||
@@ -122,8 +127,13 @@ if __name__ == '__main__':
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even if we do not need them:
|
||||
metric = VALIDATION_METRIC()
|
||||
n_st_per_batch = n_st // len(train_loader)
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(n_st // len(train_loader), 1)
|
||||
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
|
||||
total=len(train_loader), desc='MetaNet Train - Batch'
|
||||
@@ -166,17 +176,14 @@ if __name__ == '__main__':
|
||||
train_store.loc[train_store.shape[0]] = val_step_log
|
||||
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
|
||||
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
###########################################################
|
||||
# EPOCHS ended
|
||||
@@ -200,38 +207,48 @@ if __name__ == '__main__':
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
|
||||
plot_training_result(df_store_path)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
|
||||
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('Model pattern did not trigger.')
|
||||
print(f'Search path was: {seed_path}:')
|
||||
print(f'Found models are: {list(seed_path.rglob("*.tp"))}')
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=valid_loader, metric_class=VALIDATION_METRIC)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
if plotting:
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
if n_seeds >= 2:
|
||||
combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
|
||||
# noinspection PyUnboundLocalVariable
|
||||
found_train_stores = exp_path.rglob(df_store_path.name)
|
||||
train_dfs = []
|
||||
for found_train_store in found_train_stores:
|
||||
train_store_df = pd.read_csv(found_train_store, index_col=False)
|
||||
train_store_df['Seed'] = int(found_train_store.parent.name)
|
||||
train_dfs.append(train_store_df)
|
||||
combined_train_df = pd.concat(train_dfs)
|
||||
combined_train_df.to_csv(combined_df_store_path, index=False)
|
||||
plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
|
||||
plot_name=f"{combined_df_store_path.stem}.png"
|
||||
)
|
||||
plot_training_result(df_store_path)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
try:
|
||||
for fixtype in FixTypes.all_types():
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
if robustnes:
|
||||
try:
|
||||
test_robustness(model_path, seeds=1)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if 2 <= n_seeds <= sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
||||
|
@@ -2,7 +2,6 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchmetrics
|
||||
from torch import nn
|
||||
@@ -15,32 +14,35 @@ from network import MetaNet
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
|
||||
plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype, \
|
||||
checkpoint_and_validate
|
||||
checkpoint_and_validate, plot_training_results_over_n_seeds, sanity_weight_swap, FINAL_CHECKPOINT_NAME
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
WORKER = 0
|
||||
BATCHSIZE = 50
|
||||
EPOCH = 30
|
||||
VALIDATION_FRQ = 3
|
||||
VALIDATION_METRIC = torchmetrics.MeanAbsoluteError
|
||||
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
plot_loader = DataLoader(AddTaskDataset(), batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = False
|
||||
plotting = False
|
||||
n_st = 700
|
||||
plotting = True
|
||||
n_st = 100
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
for weight_hidden_size in [3, 4, 5]:
|
||||
for weight_hidden_size in [2]:
|
||||
|
||||
tsk_threshold = 0.85
|
||||
weight_hidden_size = weight_hidden_size
|
||||
residual_skip = True
|
||||
n_seeds = 10
|
||||
n_seeds = 3
|
||||
depth = 3
|
||||
width = 3
|
||||
out = 1
|
||||
@@ -97,8 +99,8 @@ if __name__ == '__main__':
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even if we do not need them:
|
||||
metric = VALIDATION_METRIC()
|
||||
n_st_per_batch = n_st // len(train_load)
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(1, (n_st // len(train_load)))
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
|
||||
total=len(train_load), desc='MetaNet Train - Batch'
|
||||
@@ -125,7 +127,7 @@ if __name__ == '__main__':
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
|
||||
validation_metric=VALIDATION_METRIC).item()
|
||||
validation_metric=VAL_METRIC_CLASS).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
@@ -163,7 +165,7 @@ if __name__ == '__main__':
|
||||
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
accuracy = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
|
||||
validation_metric=VALIDATION_METRIC)
|
||||
validation_metric=VAL_METRIC_CLASS)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
|
||||
for particle in metanet.particles:
|
||||
@@ -175,11 +177,11 @@ if __name__ == '__main__':
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
if plotting:
|
||||
|
||||
plot_training_result(df_store_path, metric=VAL_METRIC_NAME)
|
||||
plot_training_result(df_store_path, metric_name=VAL_METRIC_NAME)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
|
||||
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('####################################################')
|
||||
print('ERROR: Model pattern did not trigger.')
|
||||
@@ -190,7 +192,7 @@ if __name__ == '__main__':
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=vali_load, metric_class=VALIDATION_METRIC)
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
@@ -204,24 +206,15 @@ if __name__ == '__main__':
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
|
||||
model = torch.load(model_path, map_location='cpu')
|
||||
test_robustness(list(model.particles), seed_path)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
test_robustness(model_path, seeds=10)
|
||||
pass
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if n_seeds >= 2:
|
||||
combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
|
||||
# noinspection PyUnboundLocalVariable
|
||||
found_train_stores = exp_path.rglob(df_store_path.name)
|
||||
train_dfs = []
|
||||
for found_train_store in found_train_stores:
|
||||
train_store_df = pd.read_csv(found_train_store, index_col=False)
|
||||
train_store_df['Seed'] = int(found_train_store.parent.name)
|
||||
train_dfs.append(train_store_df)
|
||||
combined_train_df = pd.concat(train_dfs)
|
||||
combined_train_df.to_csv(combined_df_store_path, index=False)
|
||||
plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
|
||||
plot_name=f"{combined_df_store_path.stem}.png"
|
||||
)
|
||||
if 2 <= n_seeds == sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
||||
|
network.py
@@ -443,7 +443,8 @@ class MetaNet(nn.Module):
|
||||
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
|
||||
)
|
||||
replaced_particles += 1
|
||||
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
|
||||
if replaced_particles != 0:
|
||||
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
|
||||
return self
|
||||
|
||||
def forward(self, x):
|
||||
@@ -538,17 +539,15 @@ class MetaNetCompareBaseline(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self._first_layer(x)
|
||||
if self.activation:
|
||||
tensor = self.activation(tensor)
|
||||
residual = None
|
||||
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||
tensor = meta_layer(tensor)
|
||||
if idx % 2 == 1 and self.residual_skip:
|
||||
# if idx % 2 == 1 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
residual = tensor
|
||||
if idx % 2 == 0 and self.residual_skip:
|
||||
tensor = meta_layer(tensor)
|
||||
# if idx % 2 == 0 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
if self.activation:
|
||||
tensor = self.activation(tensor)
|
||||
tensor = self._last_layer(tensor)
|
||||
return tensor
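With the idx % 2 guards commented out, the baseline now wraps every meta layer in a residual connection instead of every second one; a behavior-equivalent sketch of the new loop body:
for meta_layer in self._meta_layer_list:
    if self.residual_skip:
        residual = tensor            # save the layer input
    tensor = meta_layer(tensor)
    if self.residual_skip:
        tensor = tensor + residual   # per-layer skip connection
    if self.activation:
        tensor = self.activation(tensor)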
|
||||
|
||||
|
@@ -24,34 +24,35 @@ def plot_single_3d_trajectories_by_layer(model_path, all_weights_path, status_ty
|
||||
|
||||
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
|
||||
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()]
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
ax = plt.axes(projection='3d')
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
plt.tight_layout()
|
||||
if num_status_of_layer != 0:
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()]
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
ax = plt.axes(projection='3d')
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
plt.tight_layout()
|
||||
|
||||
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
|
||||
if status == status_type:
|
||||
pca.fit(weights_of_net)
|
||||
transformed_trajectory = pca.transform(weights_of_net)
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
|
||||
if status == status_type:
|
||||
pca.fit(weights_of_net)
|
||||
transformed_trajectory = pca.transform(weights_of_net)
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
|
||||
|
@@ -28,9 +28,10 @@ def extract_weights_from_model(model:MetaNet)->dict:
|
||||
return dict(weights)
|
||||
|
||||
|
||||
def test_weights_as_model(meta_net, new_weights:dict, data):
|
||||
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth, width=meta_net.width, out=meta_net.out,
|
||||
residual_skip=True)
|
||||
def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
|
||||
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
|
||||
width=meta_net.width, out=meta_net.out,
|
||||
residual_skip=meta_net.residual_skip)
|
||||
with torch.no_grad():
|
||||
new_weight_values = list(new_weights.values())
|
||||
old_parameters = list(transfer_net.parameters())
|
||||
@@ -39,40 +40,18 @@ def test_weights_as_model(meta_net, new_weights:dict, data):
|
||||
parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
|
||||
|
||||
transfer_net.eval()
|
||||
|
||||
# Test if the margin of error is similar
|
||||
|
||||
im_t = defaultdict(list)
|
||||
rand = torch.randn((1, 15 * 15))
|
||||
for net in [meta_net, transfer_net]:
|
||||
tensor = rand.clone()
|
||||
for layer in net.all_layers:
|
||||
tensor = layer(tensor)
|
||||
im_t[net.__class__.__name__].append(tensor.detach())
|
||||
|
||||
im_t = dict(im_t)
|
||||
|
||||
all_close = {f'layer_{idx}': torch.allclose(y1.detach(), y2.detach(), rtol=0, atol=e
|
||||
) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
|
||||
}
|
||||
print(f'Cummulative differences per layer is smaller then {e}:\n {all_close}')
|
||||
# all_errors = {f'layer_{idx}': torch.absolute(y1.detach(), y2.detach(), rtol=0, atol=e
|
||||
# ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
|
||||
# }
|
||||
|
||||
results = dict()
|
||||
for net in [meta_net, transfer_net]:
|
||||
net.eval()
|
||||
metric = torchmetrics.Accuracy()
|
||||
with tqdm(desc='Test Batch: ') as pbar:
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
|
||||
y = net(batch_x)
|
||||
acc = metric(y.cpu(), batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Acc: {acc}')
|
||||
pbar.update()
|
||||
metric = metric_class()
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
|
||||
y = net(batch_x)
|
||||
metric(y.cpu(), batch_y.cpu())
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
acc = metric.compute()
|
||||
tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")
|
||||
# metric on all batches using custom accumulation
|
||||
measure = metric.compute()
|
||||
results[net.__class__.__name__] = measure.item()
|
||||
return results
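For reference, a sketch of how this function is driven from sanity_weight_swap in meta_task_utility.py (variable names as used there):
model = torch.load(found_model, map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
# results maps each network class name (MetaNet, MetaNetCompareBaseline) to its computed metric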
|
||||
|
||||
|
||||
if __name__ == '__main__':