refactoring and running experiments

Steffen Illium
2022-03-05 16:51:19 +01:00
parent 69c904e156
commit f3ff4c9239
5 changed files with 196 additions and 203 deletions

View File

@@ -1,12 +1,5 @@
-from pathlib import Path
 import torch
-import torchmetrics
 from torch.utils.data import Dataset
-from tqdm import tqdm
-from experiments.meta_task_utility import set_checkpoint

 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -24,31 +17,6 @@ class AddTaskDataset(Dataset):
         return ab, ab.sum(axis=-1, keepdims=True)


-def validate(checkpoint_path, valid_d, ratio=1, validmetric=torchmetrics.MeanAbsoluteError()):
-    checkpoint_path = Path(checkpoint_path)
-
-    # initialize metric
-    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
-    n_samples = int(len(valid_d) * ratio)
-
-    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
-        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
-            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
-            y_valid = model(valid_batch_x)
-
-            # metric on current batch
-            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
-            pbar.set_postfix_str(f'Acc: {acc}')
-            pbar.update()
-            if idx == n_samples:
-                break
-
-    # metric on all batches using custom accumulation
-    acc = validmetric.compute()
-    tqdm.write(f"Avg. Accuracy on all data: {acc}")
-    return acc
-
-
 def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
     # Zero your gradients for every batch!
     optimizer.zero_grad()
@@ -66,12 +34,5 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
     return stp_log, y_prd


-def checkpoint_and_validate(model, out_path, epoch_n, valid_d, final_model=False):
-    out_path = Path(out_path)
-    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
-    result = validate(ckpt_path, valid_d)
-    return result
-
-
 if __name__ == '__main__':
     raise(NotImplementedError('Get out of here'))
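The small add-task utility now keeps only the dataset and the per-batch training step; validation goes through the shared helpers in experiments.meta_task_utility, which take an explicit DataLoader and a torchmetrics metric class. A minimal usage sketch of that call shape; the nn.Linear stand-in and the checkpoint file name are illustrative only, the experiments in this commit pass a trained MetaNet checkpoint instead:

    from pathlib import Path
    import torch
    import torchmetrics
    from torch.utils.data import DataLoader
    from experiments.meta_task_small_utility import AddTaskDataset
    from experiments.meta_task_utility import validate

    # Stand-in model with the add-task shape (2 inputs -> 1 output), saved like a checkpoint.
    ckpt = Path('example_ckpt.tp')
    torch.save(torch.nn.Linear(2, 1), ckpt)

    valid_loader = DataLoader(AddTaskDataset(), batch_size=50, shuffle=True, drop_last=True)
    mae = validate(ckpt, valid_loader, metric_class=torchmetrics.MeanAbsoluteError)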

View File

@@ -3,57 +3,30 @@ import re
 import shutil
 from collections import defaultdict
 from pathlib import Path
-import sys
-import platform

 import pandas as pd
 import numpy as np
 import torch
+import torchmetrics
 from matplotlib import pyplot as plt
 import seaborn as sns
-from torch.nn import Flatten
-from torch.utils.data import Dataset, DataLoader
-from torchvision.datasets import MNIST
-from torchvision.transforms import ToTensor, Compose, Resize
+from torch.utils.data import Dataset
 from tqdm import tqdm

-# noinspection DuplicatedCode
-if platform.node() == 'CarbonX':
-    debug = True
-    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-    print("@ Warning, Debugging Config@!!!!!! @")
-    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
-else:
-    debug = False
-
-try:
-    # noinspection PyUnboundLocalVariable
-    if __package__ is None:
-        DIR = Path(__file__).resolve().parent
-        sys.path.insert(0, str(DIR.parent))
-        __package__ = DIR.name
-    else:
-        DIR = None
-except NameError:
-    DIR = None
-    pass
-
 from network import FixTypes as ft
 from functionalities_test import test_for_fixpoints

-WORKER = 10 if not debug else 0
-debug = False
-BATCHSIZE = 500 if not debug else 50
+WORKER = 10
+BATCHSIZE = 500
 EPOCH = 50
-VALIDATION_FRQ = 3 if not debug else 1
-SELF_TRAIN_FRQ = 1 if not debug else 1
+VALIDATION_FRQ = 3
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 DATA_PATH = Path('data')
 DATA_PATH.mkdir(exist_ok=True, parents=True)
-
-if debug:
-    torch.autograd.set_detect_anomaly(True)


 class ToFloat:
@@ -93,39 +66,28 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
     return ckpt_path


-def validate(checkpoint_path, ratio=0.1):
+# noinspection PyProtectedMember
+def validate(checkpoint_path, valid_loader, metric_class=torchmetrics.Accuracy):
     checkpoint_path = Path(checkpoint_path)
-    import torchmetrics

     # initialize metric
-    validmetric = torchmetrics.Accuracy()
-    ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
-
-    try:
-        datas = MNIST(str(DATA_PATH), transform=ut, train=False)
-    except RuntimeError:
-        datas = MNIST(str(DATA_PATH), transform=ut, train=False, download=True)
-    valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+    validmetric = metric_class()

     model = torch.load(checkpoint_path, map_location=DEVICE).eval()
-    n_samples = int(len(valid_d) * ratio)

-    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
-        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
+    with tqdm(total=len(valid_loader), desc='Validation Run: ') as pbar:
+        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_loader):
             valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
             y_valid = model(valid_batch_x)

             # metric on current batch
-            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
-            pbar.set_postfix_str(f'Acc: {acc}')
+            measure = validmetric(y_valid.cpu(), valid_batch_y.cpu())
+            pbar.set_postfix_str(f'Measure: {measure}')
             pbar.update()
-            if idx == n_samples:
-                break

     # metric on all batches using custom accumulation
-    acc = validmetric.compute()
-    tqdm.write(f"Avg. accuracy on all data: {acc}")
-    return acc
+    measure = validmetric.compute()
+    tqdm.write(f"Avg. {validmetric._get_name()} on all data: {measure}")
+    return measure


 def new_storage_df(identifier, weight_count):
@@ -135,14 +97,16 @@ def new_storage_df(identifier, weight_count):
         return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])


-def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
+def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, final_model=False,
+                            validation_metric=torchmetrics.Accuracy):
     out_path = Path(out_path)
     ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
-    result = validate(ckpt_path)
+    result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
     return result


 def plot_training_particle_types(path_to_dataframe):
+    plt.close('all')
     plt.clf()
     # load from Drive
     df = pd.read_csv(path_to_dataframe, index_col=False)
@@ -158,17 +122,19 @@ def plot_training_particle_types(path_to_dataframe):
     fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))

     plt.tight_layout()
-    if debug:
-        plt.show()
-    else:
-        plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
+    plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)


-def plot_training_result(path_to_dataframe):
+def plot_training_result(path_to_dataframe, metric='Accuracy', plot_name=None):
     plt.clf()
     # load from Drive
     df = pd.read_csv(path_to_dataframe, index_col=False)

+    # Check if this is a single lineplot or if aggregated
+    group = ['Epoch', 'Metric']
+    if 'Seed' in df.columns:
+        group.append('Seed')
+
     # Set up figure
     fig, ax1 = plt.subplots()  # initializes figure and plots
     ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.
@@ -176,26 +142,27 @@ def plot_training_result(path_to_dataframe):
     # plots the first set of data
     data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
     palette = sns.color_palette()[1:data.reset_index()['Metric'].unique().shape[0]+1]
-    sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
-                 palette=palette, ax=ax1)
+    sns.lineplot(data=data.groupby(group).mean(), x='Epoch', y='Score', hue='Metric',
+                 palette=palette, ax=ax1, ci='sd')

     # plots the second set of data
-    data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
+    data = df[(df['Metric'] == f'Test {metric}') | (df['Metric'] == f'Train {metric}')]
     palette = sns.color_palette()[len(palette)+1:data.reset_index()['Metric'].unique().shape[0] + len(palette)+1]
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette)
+    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ci='sd')

     ax1.set(yscale='log', ylabel='Losses')
-    ax1.set_title('Training Lineplot')
-    ax2.set(ylabel='Accuracy')
+    # ax1.set_title('Training Lineplot')
+    ax2.set(ylabel=metric)
+    if metric != 'MAE':
+        ax2.set(yscale='log')

     fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
-    ax1.get_legend().remove()
-    ax2.get_legend().remove()
+    for ax in [ax1, ax2]:
+        if legend := ax.get_legend():
+            legend.remove()

     plt.tight_layout()
-    if debug:
-        plt.show()
-    else:
-        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
+    plt.savefig(Path(path_to_dataframe.parent / ('training_lineplot.png' if plot_name is None else plot_name)), dpi=300)


 def plot_network_connectivity_by_fixtype(path_to_trained_model):
@@ -224,31 +191,29 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
             lines = ax.get_lines()
             for line in lines:
                 line.set_color(sns.color_palette()[n])
-            if debug:
-                plt.show()
-            else:
-                plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
+            plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
             tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
         else:
             tqdm.write(f'No Connectivity {fixtype}')


-def run_particle_dropout_test(model_path):
+# noinspection PyProtectedMember
+def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
     diff_store_path = model_path.parent / 'diff_store.csv'
     latest_model = torch.load(model_path, map_location=DEVICE).eval()
     prtcl_dict = defaultdict(lambda: 0)
     _ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
     tqdm.write(str(dict(prtcl_dict)))
-    diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
+    diff_df = pd.DataFrame(columns=['Particle Type', metric_class()._get_name(), 'Diff'])

-    acc_pre = validate(model_path, ratio=1).item()
+    acc_pre = validate(model_path, valid_loader, metric_class=metric_class).item()
     diff_df.loc[diff_df.shape[0]] = ('All Organism', acc_pre, 0)

     for fixpoint_type in ft.all_types():
         new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
         if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
             new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
-            acc_post = validate(new_ckpt, ratio=1).item()
+            acc_post = validate(new_ckpt, valid_loader, metric_class=metric_class).item()
             acc_diff = abs(acc_post - acc_pre)
             tqdm.write(f'Zero_ident diff = {acc_diff}')
             diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
@@ -257,8 +222,9 @@ def run_particle_dropout_test(model_path):
     return diff_store_path


-def plot_dropout_stacked_barplot(mdl_path, diff_store_path):
+# noinspection PyProtectedMember
+def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
+    metric_name = metric_class()._get_name()
     diff_df = pd.read_csv(diff_store_path)
     particle_dict = defaultdict(lambda: 0)
     latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
@@ -267,24 +233,21 @@ def plot_dropout_stacked_barplot(mdl_path, diff_store_path):
     plt.clf()
     fig, ax = plt.subplots(ncols=2)
     colors = sns.color_palette()[1:diff_df.shape[0]+1]
-    _ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
-    ax[0].set_title('Accuracy after particle dropout')
+    _ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors)
+    ax[0].set_title(f'{metric_name} after particle dropout')
     ax[0].set_xlabel('Particle Type')
     ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=list(reversed(colors)), )
     ax[1].set_title('Particle Count')

     plt.tight_layout()
-    if debug:
-        plt.show()
-    else:
-        plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
+    plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)


-def run_particle_dropout_and_plot(model_path):
-    diff_store_path = run_particle_dropout_test(model_path)
-    plot_dropout_stacked_barplot(model_path, diff_store_path)
+def run_particle_dropout_and_plot(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
+    diff_store_path = run_particle_dropout_test(model_path, valid_loader=valid_loader, metric_class=metric_class)
+    plot_dropout_stacked_barplot(model_path, diff_store_path, metric_class=metric_class)


 def flat_for_store(parameters):
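The Seed-aware grouping and the ci='sd' arguments added to plot_training_result above only pay off once a combined dataframe with a Seed column exists (built per experiment further down in this commit). A toy sketch with made-up numbers, assuming nothing beyond pandas and seaborn (ci='sd' is the pre-0.12 seaborn spelling used throughout this repo):

    import pandas as pd
    import seaborn as sns

    # Three seeds of a fake 'Task Loss' curve over ten epochs.
    rows = [dict(Epoch=e, Metric='Task Loss', Score=1 / (e + 1) + 0.01 * s, Seed=s)
            for e in range(10) for s in range(3)]
    df = pd.DataFrame(rows)

    group = ['Epoch', 'Metric'] + (['Seed'] if 'Seed' in df.columns else [])
    data = df.groupby(group, as_index=False).mean()  # one row per epoch, metric and seed
    sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric', ci='sd')  # band = std over seeds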

View File

@@ -4,7 +4,7 @@ from pathlib import Path
 import platform

+import pandas as pd
 import torchmetrics
 import numpy as np
 import torch
@@ -17,9 +17,9 @@ from torchvision.transforms import ToTensor, Compose, Resize
 from tqdm import tqdm

 # noinspection DuplicatedCode
-from experiments.meta_task_utility import ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store, \
-    plot_training_result, plot_training_particle_types, plot_network_connectivity_by_fixtype, \
-    run_particle_dropout_and_plot
+from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store,
+                                           plot_training_result, plot_training_particle_types,
+                                           plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot)

 if platform.node() == 'CarbonX':
     debug = True
@@ -37,7 +37,9 @@ debug = False
 BATCHSIZE = 2000 if not debug else 50
 EPOCH = 50
 VALIDATION_FRQ = 3 if not debug else 1
-SELF_TRAIN_FRQ = 1 if not debug else 1
+VALIDATION_METRIC = torchmetrics.Accuracy
+# noinspection PyProtectedMember
+VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 DATA_PATH = Path('data')
@@ -50,15 +52,15 @@ if debug:
 if __name__ == '__main__':
     training = True
-    n_st = 300
+    n_st = 150  # per batch !!
     activation = None  # nn.ReLU()

-    for weight_hidden_size in [4, 5]:
+    for weight_hidden_size in [4, 5, 6]:

         weight_hidden_size = weight_hidden_size
         residual_skip = True
         n_seeds = 3
-        depth = 3
+        depth = 5
         width = 3
         out = 10
@@ -96,6 +98,12 @@ if __name__ == '__main__':
             train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, download=True)
             train_loader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True,
                                       drop_last=True, num_workers=WORKER)
+            try:
+                valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
+            except RuntimeError:
+                valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
+            valid_loader = DataLoader(valid_dataset, batch_size=BATCHSIZE, shuffle=True,
+                                      drop_last=True, num_workers=WORKER)

             interface = np.prod(train_dataset[0][0].shape)
             metanet = MetaNet(interface, depth=depth, width=width, out=out,
@@ -111,11 +119,10 @@ if __name__ == '__main__':
             for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
                 is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-                is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
                 metanet = metanet.train()

                 # Init metrics, even we do not need:
-                metric = torchmetrics.Accuracy()
+                metric = VALIDATION_METRIC()

                 n_st_per_batch = n_st // len(train_loader)
                 for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
@@ -139,14 +146,14 @@ if __name__ == '__main__':
                     metanet = metanet.eval()
                     try:
                         validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                              Metric='Train Accuracy', Score=metric.compute().item())
+                                              Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
                         train_store.loc[train_store.shape[0]] = validation_log
                     except RuntimeError:
                         pass

-                    accuracy = checkpoint_and_validate(metanet, seed_path, epoch).item()
+                    accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch).item()
                     validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                          Metric='Test Accuracy', Score=accuracy)
+                                          Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy)
                     train_store.loc[train_store.shape[0]] = validation_log

                 if is_validation_epoch:
@@ -181,9 +188,9 @@ if __name__ == '__main__':
             for key, value in dict(counter_dict).items():
                 step_log = dict(Epoch=int(EPOCH)+1, Batch=BATCHSIZE, Metric=key, Score=value)
                 train_store.loc[train_store.shape[0]] = step_log
-            accuracy = checkpoint_and_validate(metanet, seed_path, EPOCH, final_model=True)
+            accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, EPOCH, final_model=True)
             validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
-                                  Metric='Test Accuracy', Score=accuracy.item())
+                                  Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
             train_store.loc[train_store.shape[0]] = validation_log
             for particle in metanet.particles:
                 weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
@@ -206,13 +213,25 @@ if __name__ == '__main__':
             try:
                 # noinspection PyUnboundLocalVariable
-                run_particle_dropout_and_plot(model_path)
+                run_particle_dropout_and_plot(model_path, valid_loader=valid_loader, metric_class=VALIDATION_METRIC)
             except (ValueError, NameError) as e:
                 print(e)
             try:
                 plot_network_connectivity_by_fixtype(model_path)
-            except (ValueError, NameError)as e:
+            except (ValueError, NameError) as e:
                 print(e)

         if n_seeds >= 2:
-            pass
+            combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
+            # noinspection PyUnboundLocalVariable
+            found_train_stores = exp_path.rglob(df_store_path.name)
+            train_dfs = []
+            for found_train_store in found_train_stores:
+                train_store_df = pd.read_csv(found_train_store, index_col=False)
+                train_store_df['Seed'] = int(found_train_store.parent.name)
+                train_dfs.append(train_store_df)
+            combined_train_df = pd.concat(train_dfs)
+            combined_train_df.to_csv(combined_df_store_path, index=False)
+            plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
+                                 plot_name=f"{combined_df_store_path.stem}.png"
+                                 )
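The VAL_METRIC_NAME indirection works because torchmetrics metrics are torch.nn.Module subclasses and Module._get_name() returns the class name, hence the # noinspection PyProtectedMember markers. A quick check (with the no-argument constructors used here; newer torchmetrics releases require a task argument for Accuracy):

    import torchmetrics

    print(torchmetrics.Accuracy()._get_name())           # 'Accuracy'
    print(torchmetrics.MeanAbsoluteError()._get_name())  # 'MeanAbsoluteError'

so the logged labels become 'Train Accuracy' / 'Test Accuracy' for the MNIST run and 'Train MeanAbsoluteError' / 'Test MeanAbsoluteError' for the add task.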

View File

@@ -2,37 +2,44 @@ from collections import defaultdict
 from pathlib import Path

 import numpy as np
+import pandas as pd
 import torch
 import torchmetrics
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-from experiments.meta_task_small_utility import AddTaskDataset, checkpoint_and_validate, train_task
+from experiments.meta_task_small_utility import AddTaskDataset, train_task

 from network import MetaNet
-from functionalities_test import test_for_fixpoints
+from functionalities_test import test_for_fixpoints, FixTypes as ft
 from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
-    plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype
+    plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype, \
+    checkpoint_and_validate
+from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer

 WORKER = 0
 BATCHSIZE = 50
 EPOCH = 30
 VALIDATION_FRQ = 3
+VALIDATION_METRIC = torchmetrics.MeanAbsoluteError
+# noinspection PyProtectedMember
+VAL_METRIC_NAME = VALIDATION_METRIC()._get_name()
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


 if __name__ == '__main__':
     training = True
-    n_st = 500
+    plotting = True
+    n_st = 700
     activation = None  # nn.ReLU()

-    for weight_hidden_size in [3,4,5]:
+    for weight_hidden_size in [3, 4, 5, 6]:

         tsk_threshold = 0.85
         weight_hidden_size = weight_hidden_size
         residual_skip = True
-        n_seeds = 3
+        n_seeds = 10
         depth = 3
         width = 3
         out = 1
@@ -48,9 +55,9 @@ if __name__ == '__main__':
         config_str = f'{res_str}'
         exp_path = Path('output') / f'add_st_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'

-        if not training:
-            # noinspection PyRedeclaration
-            exp_path = Path('output') / 'mn_st_n_2_100_4'
+        # if not training:
+        #     # noinspection PyRedeclaration
+        #     exp_path = Path('output') / f'add_st_{n_st}_{weight_hidden_size}'

         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)
@@ -60,17 +67,18 @@ if __name__ == '__main__':
             weight_store_path = seed_path / 'weight_store.csv'
             srnn_parameters = dict()

+            valid_data = AddTaskDataset()
+            vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=True,
+                                   drop_last=True, num_workers=WORKER)
+
             if training:
                 # Check if files do exist on project location, warn and break.
                 for path in [model_path, df_store_path, weight_store_path]:
                     assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'

                 train_data = AddTaskDataset()
-                valid_data = AddTaskDataset()
                 train_load = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True,
                                         drop_last=True, num_workers=WORKER)
-                vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=False,
-                                       drop_last=True, num_workers=WORKER)

                 interface = np.prod(train_data[0][0].shape)
                 metanet = MetaNet(interface, depth=depth, width=width, out=out,
@@ -89,7 +97,7 @@ if __name__ == '__main__':
                     metanet = metanet.train()

                     # Init metrics, even we do not need:
-                    metric = torchmetrics.MeanAbsoluteError()
+                    metric = VALIDATION_METRIC()

                     n_st_per_batch = n_st // len(train_load)
                     for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
@@ -113,12 +121,13 @@ if __name__ == '__main__':
                         metanet = metanet.eval()
                         if metric.total.item():
                             validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                                  Metric='Train Accuracy', Score=metric.compute().item())
+                                                  Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
                             train_store.loc[train_store.shape[0]] = validation_log

-                        accuracy = checkpoint_and_validate(metanet, seed_path, epoch, vali_load).item()
+                        mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
+                                                      validation_metric=VALIDATION_METRIC).item()
                         validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                              Metric='Test Accuracy', Score=accuracy)
+                                              Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
                         train_store.loc[train_store.shape[0]] = validation_log

                     if is_validation_epoch:
@@ -153,9 +162,10 @@ if __name__ == '__main__':
                 for key, value in dict(counter_dict).items():
                     step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
                     train_store.loc[train_store.shape[0]] = step_log
-                accuracy = checkpoint_and_validate(metanet, seed_path, EPOCH, vali_load, final_model=True)
+                accuracy = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
+                                                   validation_metric=VALIDATION_METRIC)
                 validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
-                                      Metric='Test Accuracy', Score=accuracy.item())
+                                      Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
                 for particle in metanet.particles:
                     weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
                     weight_store.loc[weight_store.shape[0]] = weight_log
@@ -163,26 +173,48 @@ if __name__ == '__main__':
                 train_store.loc[train_store.shape[0]] = validation_log
                 train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
                 weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)

+            if plotting:
-            plot_training_result(df_store_path)
+                plot_training_result(df_store_path, metric=VAL_METRIC_NAME)
                 plot_training_particle_types(df_store_path)

                 try:
                     model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
                 except StopIteration:
-                    print('Model pattern did not trigger.')
-                    print(f'Search path was: {seed_path}:')
-                    print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
-                    exit(1)
+                    print('####################################################')
+                    print('ERROR: Model pattern did not trigger.')
+                    print(f'INFO: Search path was: {seed_path}:')
+                    print(f'INFO: Found Models are: {list(seed_path.rglob(".tp"))}')
+                    print('####################################################')
+                    exit(1)

                 try:
-                    run_particle_dropout_and_plot(model_path)
+                    run_particle_dropout_and_plot(model_path, valid_loader=vali_load, metric_class=VALIDATION_METRIC)
                 except ValueError as e:
-                    print(e)
+                    print('ERROR:', e)
                 try:
                     plot_network_connectivity_by_fixtype(model_path)
                 except ValueError as e:
-                    print(e)
+                    print('ERROR:', e)
+                try:
+                    plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
+                    plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
+                    plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
+                    plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
+                except ValueError as e:
+                    print('ERROR:', e)

         if n_seeds >= 2:
-            pass
+            combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'
+            # noinspection PyUnboundLocalVariable
+            found_train_stores = exp_path.rglob(df_store_path.name)
+            train_dfs = []
+            for found_train_store in found_train_stores:
+                train_store_df = pd.read_csv(found_train_store, index_col=False)
+                train_store_df['Seed'] = int(found_train_store.parent.name)
+                train_dfs.append(train_store_df)
+            combined_train_df = pd.concat(train_dfs)
+            combined_train_df.to_csv(combined_df_store_path, index=False)
+            plot_training_result(combined_df_store_path, metric=VAL_METRIC_NAME,
+                                 plot_name=f"{combined_df_store_path.stem}.png"
+                                 )

View File

@@ -1,13 +1,20 @@
 import pandas as pd
-from pathlib import Path
 import torch
 import numpy as np
-from network import MetaNet, FixTypes
+from network import FixTypes
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA


-def plot_single_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFrame, save_path:Path, status_type:FixTypes):
-    ''' This plots one PCA for every net (over its n epochs) as one trajectory and then combines all of them in one plot '''
+def plot_single_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
+    """
+    This plots one PCA for every net (over its n epochs) as one trajectory
+    and then combines all of them in one plot
+    """
+    model = torch.load(model_path, map_location=torch.device('cpu')).eval()
+    all_weights = pd.read_csv(all_weights_path, index_col=False)
+    save_path = model_path.parent / 'trajec_plots'
     all_epochs = all_weights.Epoch.unique()
     pca = PCA(n_components=2, whiten=True)
@@ -18,8 +25,9 @@ def plot_single_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFrame
         fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
         num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
         layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
-        weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:,2:] for name in layer.Weight.unique()]
+        weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
+                          for name in layer.Weight.unique()]

+        plt.clf()
         fig = plt.figure()
         ax = plt.axes(projection='3d')
         fig.set_figheight(10)
@@ -39,15 +47,19 @@ def plot_single_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFrame
         ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
         ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
         ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
-        ax.set_zlabel('Epochs', fontsize=30, rotation = 0)
+        ax.set_zlabel('Epochs', fontsize=30, rotation=0)
         file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
         plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
         plt.clf()
+        plt.close(fig)


-def plot_grouped_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFrame, save_path:Path, status_type:FixTypes):
-    ''' This computes the PCA over all the net-weights at once and then plots that.'''
+def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
+    """ This computes the PCA over all the net-weights at once and then plots that."""
+    model = torch.load(model_path, map_location=torch.device('cpu')).eval()
+    save_path = model_path.parent / 'trajec_plots'
+    all_weights = pd.read_csv(all_weights_path, index_col=False)
     all_epochs = all_weights.Epoch.unique()
     pca = PCA(n_components=2, whiten=True)
     save_path.mkdir(exist_ok=True, parents=True)
@@ -57,8 +69,9 @@ def plot_grouped_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFram
         fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
         num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
         layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
-        weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:,2:] for name in layer.Weight.unique()])
+        weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
+                                    for name in layer.Weight.unique()])

+        plt.clf()
         fig = plt.figure()
         fig.set_figheight(10)
         fig.set_figwidth(12)
@@ -67,7 +80,8 @@ def plot_grouped_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFram
         pca.fit(weight_batches)
         w_transformed = pca.transform(weight_batches)
-        for transformed_trajectory,status in zip(np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
+        for transformed_trajectory, status in zip(
+                np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
             if status == status_type:
                 xdata = transformed_trajectory[:, 0]
                 ydata = transformed_trajectory[:, 1]
@@ -78,13 +92,16 @@ def plot_grouped_3d_trajectories_by_layer(model:MetaNet, all_weights:pd.DataFram
         ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
         ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
         ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
-        ax.set_zlabel('Epochs', fontsize=30, rotation = 0)
+        ax.set_zlabel('Epochs', fontsize=30, rotation=0)
         file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}_grouped.png"
         plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
         plt.clf()
+        plt.close(fig)


 if __name__ == '__main__':
+    raise (NotImplementedError('Get out of here'))
+    """
     weight_path = Path("weight_store.csv")
     model_path = Path("trained_model_ckpt_e100.tp")
     save_path = Path("figures/3d_trajectories/")
@@ -95,5 +112,6 @@ if __name__ == '__main__':
     plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.identity_func)
     plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.other_func)
-    #plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.identity_func)
+    plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.identity_func)
     #plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.other_func)
+    """