MetaNetworks Debugged
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import pickle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import platform
|
||||
@ -7,7 +7,14 @@ import platform
|
||||
import pandas as pd
|
||||
import torchmetrics
|
||||
|
||||
if platform.node() != 'CarbonX':
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
@ -20,8 +27,7 @@ if platform.node() != 'CarbonX':
|
||||
except NameError:
|
||||
DIR = None
|
||||
pass
|
||||
else:
|
||||
debug = True
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -31,7 +37,7 @@ from torch import nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from network import MetaNet
|
||||
@ -69,7 +75,7 @@ class AddTaskDataset(Dataset):
|
||||
|
||||
def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
epoch_n = str(epoch_n)
|
||||
if final_model:
|
||||
if not final_model:
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
|
||||
@ -84,32 +90,32 @@ def validate(checkpoint_path, ratio=0.1):
|
||||
import torchmetrics
|
||||
|
||||
# initialize metric
|
||||
metric = torchmetrics.Accuracy()
|
||||
validmetric = torchmetrics.Accuracy()
|
||||
|
||||
try:
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
|
||||
datas = MNIST(str(data_path), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
|
||||
valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
|
||||
n_samples = int(len(d) * ratio)
|
||||
|
||||
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
|
||||
for idx, (batch_x, batch_y) in enumerate(d):
|
||||
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
|
||||
y = model(batch_x)
|
||||
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
|
||||
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
|
||||
y_valid = model(valid_batch_x)
|
||||
|
||||
# metric on current batch
|
||||
acc = metric(y.cpu(), batch_y.cpu())
|
||||
acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Acc: {acc}')
|
||||
pbar.update()
|
||||
if idx == n_samples:
|
||||
break
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
acc = metric.compute()
|
||||
print(f"Accuracy on all data: {acc}")
|
||||
acc = validmetric.compute()
|
||||
tqdm.write(f"Avg. accuracy on all data: {acc}")
|
||||
return acc
|
||||
|
||||
|
||||
@ -125,41 +131,39 @@ def plot_training_result(path_to_dataframe):
|
||||
df = pd.read_csv(path_to_dataframe, index_col=0)
|
||||
|
||||
fig, ax1 = plt.subplots() # initializes figure and plots
|
||||
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y axis.
|
||||
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
|
||||
|
||||
# plots the first set of data, and sets it to ax1.
|
||||
data = df[df['Metric'] == 'BatchLoss']
|
||||
# plots the second set, and sets to ax2.
|
||||
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
|
||||
data = df[df['Metric'] == 'Test Accuracy']
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
|
||||
data = df[df['Metric'] == 'Train Accuracy']
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')
|
||||
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
|
||||
data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
|
||||
|
||||
ax2.set(yscale='log')
|
||||
ax1.set(yscale='log')
|
||||
ax1.set_title('Training Lineplot')
|
||||
plt.tight_layout()
|
||||
if debug:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
self_train = True
|
||||
soup_interaction = True
|
||||
training = True
|
||||
plotting = True
|
||||
self_train = False
|
||||
training = False
|
||||
plotting = False
|
||||
particle_analysis = True
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
run_path = Path('output') / 'intergrated_self_train'
|
||||
run_path = Path('output') / 'mnist_test_half_size'
|
||||
model_path = run_path / '0000_trained_model.zip'
|
||||
|
||||
if training:
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
|
||||
|
||||
try:
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms)
|
||||
@ -179,6 +183,8 @@ if __name__ == '__main__':
|
||||
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
|
||||
if is_validation_epoch:
|
||||
metric = torchmetrics.Accuracy()
|
||||
else:
|
||||
metric = None
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
|
||||
if self_train and is_self_train_epoch:
|
||||
# Zero your gradients for every batch!
|
||||
@ -221,10 +227,19 @@ if __name__ == '__main__':
|
||||
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric='Test Accuracy', Score=accuracy.item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
train_store.to_csv(run_path / 'train_store.csv')
|
||||
|
||||
if plotting:
|
||||
plot_training_result(run_path / 'train_store.csv')
|
||||
|
||||
if particle_analysis:
|
||||
model_path = next(run_path.glob('*.tp'))
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
analysis_dict = defaultdict(dict)
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
for particle in latest_model.particles:
|
||||
analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
|
||||
test_for_fixpoints(counter_dict, latest_model.particles)
|
||||
|
||||
|
Reference in New Issue
Block a user