refactoring and running experiments

This commit is contained in:
Steffen Illium
2022-03-05 16:51:19 +01:00
parent 69c904e156
commit f3ff4c9239
5 changed files with 196 additions and 203 deletions


@@ -1,12 +1,5 @@
from pathlib import Path
import torch
import torchmetrics
from torch.utils.data import Dataset
from tqdm import tqdm
from experiments.meta_task_utility import set_checkpoint
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -24,31 +17,6 @@ class AddTaskDataset(Dataset):
        return ab, ab.sum(axis=-1, keepdims=True)

def validate(checkpoint_path, valid_d, ratio=1, validmetric=torchmetrics.MeanAbsoluteError()):
    checkpoint_path = Path(checkpoint_path)

    # initialize metric
    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
    n_samples = int(len(valid_d) * ratio)

    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
            y_valid = model(valid_batch_x)

            # metric on current batch
            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()
            if idx == n_samples:
                break

    # metric on all batches using custom accumulation
    acc = validmetric.compute()
    tqdm.write(f"Avg. Accuracy on all data: {acc}")
    return acc
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
    # Zero your gradients for every batch!
    optimizer.zero_grad()
@@ -66,12 +34,5 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
    return stp_log, y_prd
def checkpoint_and_validate(model, out_path, epoch_n, valid_d, final_model=False):
    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
    result = validate(ckpt_path, valid_d)
    return result

if __name__ == '__main__':
    raise NotImplementedError('Get out of here')
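The AddTask validate() above relies on torchmetrics' stateful update/compute API. One detail worth noting: a metric instance keeps accumulating until reset() is called, which matters here because an instance is used as a default argument in the signature. A minimal, self-contained sketch of the pattern with toy tensors (no project code assumed):

```python
import torch
import torchmetrics

# Stateful metric: each forward call updates the running aggregate
# and returns the value for the current batch only.
metric = torchmetrics.MeanAbsoluteError()
for _ in range(3):
    preds, target = torch.rand(8, 1), torch.rand(8, 1)
    batch_mae = metric(preds, target)   # MAE of this batch alone
overall_mae = metric.compute()          # MAE over all batches seen so far
metric.reset()                          # clear state before reusing the instance
```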


@@ -3,57 +3,30 @@ import re
import shutil
from collections import defaultdict
from pathlib import Path
import sys
import platform
import pandas as pd
import numpy as np
import torch
import torchmetrics
from matplotlib import pyplot as plt
import seaborn as sns
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm
# noinspection DuplicatedCode
if platform.node() == 'CarbonX':
    debug = True
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
    print("@ Warning, Debugging Config@!!!!!! @")
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
    debug = False
try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass
from network import FixTypes as ft
from functionalities_test import test_for_fixpoints
WORKER = 10 if not debug else 0
debug = False
BATCHSIZE = 500 if not debug else 50
WORKER = 10
BATCHSIZE = 500
EPOCH = 50
VALIDATION_FRQ = 3 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
VALIDATION_FRQ = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_PATH = Path('data')
DATA_PATH.mkdir(exist_ok=True, parents=True)
if debug:
    torch.autograd.set_detect_anomaly(True)
class ToFloat:
@@ -93,39 +66,28 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
    return ckpt_path
def validate(checkpoint_path, ratio=0.1):
# noinspection PyProtectedMember
def validate(checkpoint_path, valid_loader, metric_class=torchmetrics.Accuracy):
    checkpoint_path = Path(checkpoint_path)
    import torchmetrics
    # initialize metric
    validmetric = torchmetrics.Accuracy()
    ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
    try:
        datas = MNIST(str(DATA_PATH), transform=ut, train=False)
    except RuntimeError:
        datas = MNIST(str(DATA_PATH), transform=ut, train=False, download=True)
    valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
    validmetric = metric_class()
    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
    n_samples = int(len(valid_d) * ratio)

    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
    with tqdm(total=len(valid_loader), desc='Validation Run: ') as pbar:
        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_loader):
            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
            y_valid = model(valid_batch_x)

            # metric on current batch
            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            measure = validmetric(y_valid.cpu(), valid_batch_y.cpu())
            pbar.set_postfix_str(f'Measure: {measure}')
            pbar.update()
            if idx == n_samples:
                break

    # metric on all batches using custom accumulation
    acc = validmetric.compute()
    tqdm.write(f"Avg. accuracy on all data: {acc}")
    return acc
    measure = validmetric.compute()
    tqdm.write(f"Avg. {validmetric._get_name()} on all data: {measure}")
    return measure
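The refactored validate() no longer owns the data pipeline: the MNIST loader that the removed lines built inline is now expected from the caller. A sketch of the assumed caller-side setup, reusing ToFloat, DATA_PATH, BATCHSIZE and WORKER defined in this file (the checkpoint path is illustrative):

```python
from torch.nn import Flatten
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize

# Build the loader once, outside validate(), and reuse it across runs.
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
valid_data = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
valid_loader = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=True,
                          drop_last=True, num_workers=WORKER)

measure = validate('out/experiment/ckpt_e50.tp', valid_loader)  # hypothetical path
```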
def new_storage_df(identifier, weight_count):
@@ -135,14 +97,16 @@ def new_storage_df(identifier, weight_count):
    return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, final_model=False,
                            validation_metric=torchmetrics.Accuracy):
    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
    result = validate(ckpt_path)
    result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
    return result
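Note the design choice: checkpoint_and_validate() now takes the metric class rather than an instance, and validate() instantiates it fresh per run, so no metric state leaks across validations. A hypothetical call for a regression task (names illustrative):

```python
# Swap the default Accuracy for MAE by passing the class, not an instance.
result = checkpoint_and_validate(model, valid_loader, out_path='out/experiment',
                                 epoch_n=10, validation_metric=torchmetrics.MeanAbsoluteError)
```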
def plot_training_particle_types(path_to_dataframe):
    plt.close('all')
    plt.clf()
    # load from Drive
    df = pd.read_csv(path_to_dataframe, index_col=False)
@@ -158,17 +122,19 @@ def plot_training_particle_types(path_to_dataframe):
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
plt.tight_layout()
if debug:
plt.show()
else:
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
def plot_training_result(path_to_dataframe):
def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
    plt.clf()
    # load from Drive
    df = pd.read_csv(path_to_dataframe, index_col=False)

    # Check if this is a single lineplot or if aggregated
    group = ['Epoch', 'Metric']
    if 'Seed' in df.columns:
        group.append('Seed')

    # Set up figure
    fig, ax1 = plt.subplots()  # initializes figure and plots
    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.
@@ -176,26 +142,27 @@ def plot_training_result(path_to_dataframe):
    # plots the first set of data
    data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
    palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
    sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
                 palette=palette, ax=ax1)
    sns.lineplot(data=data.groupby(group).mean(), x='Epoch', y='Score', hue='Metric',
                 palette=palette, ax=ax1, ci='sd')

    # plots the second set of data
    data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
    data = df[(df['Metric'] == f'Test {metric}') | (df['Metric'] == f'Train {metric}')]
    palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette)
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ci='sd')

    ax1.set(yscale='log', ylabel='Losses')
    ax1.set_title('Training Lineplot')
    ax2.set(ylabel='Accuracy')
    # ax1.set_title('Training Lineplot')
    ax2.set(ylabel=metric)
    if metric != 'MAE':
        ax2.set(yscale='log')

    fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
    ax1.get_legend().remove()
    ax2.get_legend().remove()
    for ax in [ax1, ax2]:
        if legend := ax.get_legend():
            legend.remove()
    plt.tight_layout()
    if debug:
        plt.show()
    else:
        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
    plt.savefig(Path(path_to_dataframe.parent / ('training_lineplot.png' if plot_name is None else plot_name)), dpi=300)
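The newly added ci='sd' makes seaborn draw a standard-deviation band over the grouped runs (e.g. across seeds) instead of bootstrapped confidence intervals. On seaborn >= 0.12 the ci parameter is deprecated in favour of the errorbar API; a self-contained sketch of the equivalent call, using a toy frame with the Epoch/Metric/Score columns this file's dataframes use:

```python
import pandas as pd
import seaborn as sns

# Toy data shaped like the training log: repeated Epoch values act as "runs".
df = pd.DataFrame({'Epoch': [0, 0, 1, 1],
                   'Metric': ['Test Accuracy'] * 4,
                   'Score': [0.10, 0.20, 0.30, 0.40]})

# errorbar='sd' is the modern spelling of ci='sd'.
sns.lineplot(data=df, x='Epoch', y='Score', hue='Metric', errorbar='sd')
```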
def plot_network_connectivity_by_fixtype(path_to_trained_model):
@@ -224,31 +191,29 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
        lines = ax.get_lines()
        for line in lines:
            line.set_color(sns.color_palette()[n])
        if debug:
            plt.show()
        else:
            plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
        plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
        tqdm.write(f'Connectivity plotted: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
    else:
        tqdm.write(f'No Connectivity {fixtype}')
def run_particle_dropout_test(model_path):
# noinspection PyProtectedMember
def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
    diff_store_path = model_path.parent / 'diff_store.csv'
    latest_model = torch.load(model_path, map_location=DEVICE).eval()
    prtcl_dict = defaultdict(lambda: 0)
    _ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
    tqdm.write(str(dict(prtcl_dict)))
    diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
    diff_df = pd.DataFrame(columns=['Particle Type', metric_class()._get_name(), 'Diff'])

    acc_pre = validate(model_path, ratio=1).item()
    acc_pre = validate(model_path, valid_loader, metric_class=metric_class).item()
    diff_df.loc[diff_df.shape[0]] = ('All Organism', acc_pre, 0)

    for fixpoint_type in ft.all_types():
        new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
        if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
            new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
            acc_post = validate(new_ckpt, ratio=1).item()
            acc_post = validate(new_ckpt, valid_loader, metric_class=metric_class).item()
            acc_diff = abs(acc_post - acc_pre)
            tqdm.write(f'Zero_ident diff = {acc_diff}')
            diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
@@ -257,8 +222,9 @@ def run_particle_dropout_test(model_path):
    return diff_store_path
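The # noinspection PyProtectedMember markers exist because _get_name() is technically protected: torchmetrics metrics inherit from torch.nn.Module, whose _get_name() simply returns the class name, and here that name becomes a DataFrame column header. A one-line check of the assumption:

```python
import torchmetrics

# _get_name() comes from torch.nn.Module and returns the class name.
print(torchmetrics.MeanAbsoluteError()._get_name())  # -> 'MeanAbsoluteError'
```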
def plot_dropout_stacked_barplot(mdl_path, diff_store_path):
# noinspection PyProtectedMember
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
    metric_name = metric_class()._get_name()
    diff_df = pd.read_csv(diff_store_path)
    particle_dict = defaultdict(lambda: 0)
    latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
@@ -267,24 +233,21 @@ def plot_dropout_stacked_barplot(mdl_path, diff_store_path):
    plt.clf()
    fig, ax = plt.subplots(ncols=2)
    colors = sns.color_palette()[1:diff_df.shape[0]+1]
    _ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
    _ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors)
    ax[0].set_title('Accuracy after particle dropout')
    ax[0].set_title(f'{metric_name} after particle dropout')
    ax[0].set_xlabel('Particle Type')
    ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)))
    ax[1].set_title('Particle Count')
    plt.tight_layout()
    if debug:
        plt.show()
    else:
        plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
    plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
def run_particle_dropout_and_plot(model_path):
    diff_store_path = run_particle_dropout_test(model_path)
    plot_dropout_stacked_barplot(model_path, diff_store_path)

def run_particle_dropout_and_plot(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
    diff_store_path = run_particle_dropout_test(model_path, valid_loader=valid_loader, metric_class=metric_class)
    plot_dropout_stacked_barplot(model_path, diff_store_path, metric_class=metric_class)
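Taken together, the dropout analysis is now driven entirely by the caller's loader and metric choice: zero out each particle type, re-validate, and plot the per-type drop next to the particle counts. An illustrative end-to-end call, assuming a trained checkpoint and the valid_loader built earlier (the path is hypothetical):

```python
from pathlib import Path

model_path = Path('out/experiment/trained_model_ckpt.tp')  # hypothetical checkpoint
run_particle_dropout_and_plot(model_path, valid_loader, metric_class=torchmetrics.Accuracy)
```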
def flat_for_store(parameters):