MetaNetworks Debugged

This commit is contained in:
Steffen Illium
2022-01-31 10:35:11 +01:00
parent 49c0d8a621
commit 246d825bb4
8 changed files with 169 additions and 109 deletions

View File

@ -1,5 +1,5 @@
import pickle
import time
from collections import defaultdict
from pathlib import Path
import sys
import platform
@ -7,7 +7,14 @@ import platform
import pandas as pd
import torchmetrics
if platform.node() != 'CarbonX':
from functionalities_test import test_for_fixpoints
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@ Warning, Debugging Config@!!!!!! @")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
debug = False
try:
# noinspection PyUnboundLocalVariable
@ -20,8 +27,7 @@ if platform.node() != 'CarbonX':
except NameError:
DIR = None
pass
else:
debug = True
import numpy as np
import torch
@ -31,7 +37,7 @@ from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm
from network import MetaNet
@ -69,7 +75,7 @@ class AddTaskDataset(Dataset):
def set_checkpoint(model, out_path, epoch_n, final_model=False):
epoch_n = str(epoch_n)
if final_model:
if not final_model:
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
@ -84,32 +90,32 @@ def validate(checkpoint_path, ratio=0.1):
import torchmetrics
# initialize metric
metric = torchmetrics.Accuracy()
validmetric = torchmetrics.Accuracy()
try:
dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
datas = MNIST(str(data_path), transform=utility_transforms, train=False)
except RuntimeError:
dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
n_samples = int(len(d) * ratio)
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
for idx, (batch_x, batch_y) in enumerate(d):
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y = model(batch_x)
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
y_valid = model(valid_batch_x)
# metric on current batch
acc = metric(y.cpu(), batch_y.cpu())
acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
pbar.set_postfix_str(f'Acc: {acc}')
pbar.update()
if idx == n_samples:
break
# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
acc = validmetric.compute()
tqdm.write(f"Avg. accuracy on all data: {acc}")
return acc
@ -125,41 +131,39 @@ def plot_training_result(path_to_dataframe):
df = pd.read_csv(path_to_dataframe, index_col=0)
fig, ax1 = plt.subplots() # initializes figure and plots
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y axis.
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
# plots the first set of data, and sets it to ax1.
data = df[df['Metric'] == 'BatchLoss']
# plots the second set, and sets to ax2.
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
data = df[df['Metric'] == 'Test Accuracy']
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
data = df[df['Metric'] == 'Train Accuracy']
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
ax2.set(yscale='log')
ax1.set(yscale='log')
ax1.set_title('Training Lineplot')
plt.tight_layout()
if debug:
plt.show()
else:
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
if __name__ == '__main__':
self_train = True
soup_interaction = True
training = True
plotting = True
self_train = False
training = False
plotting = False
particle_analysis = True
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'intergrated_self_train'
run_path = Path('output') / 'mnist_test_half_size'
model_path = run_path / '0000_trained_model.zip'
if training:
utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
dataset = MNIST(str(data_path), transform=utility_transforms)
@ -179,6 +183,8 @@ if __name__ == '__main__':
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
if is_validation_epoch:
metric = torchmetrics.Accuracy()
else:
metric = None
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if self_train and is_self_train_epoch:
# Zero your gradients for every batch!
@ -221,10 +227,19 @@ if __name__ == '__main__':
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
train_store.loc[train_store.shape[0]] = validation_log
torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(run_path / 'train_store.csv')
if plotting:
plot_training_result(run_path / 'train_store.csv')
if particle_analysis:
model_path = next(run_path.glob('*.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()
analysis_dict = defaultdict(dict)
counter_dict = defaultdict(lambda: 0)
for particle in latest_model.particles:
analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
test_for_fixpoints(counter_dict, latest_model.particles)