MetaNetworks Debugged II

Steffen Illium 2022-02-01 18:17:11 +01:00
parent 246d825bb4
commit 1b7581e656
4 changed files with 105 additions and 61 deletions

View File

@@ -2,4 +2,5 @@ from .mixed_setting_exp import run_mixed_experiment
from .robustness_exp import run_robustness_experiment
from .self_application_exp import run_SA_experiment
from .self_train_exp import run_ST_experiment
from .soup_exp import run_soup_experiment
import functionalities_test

View File

@@ -6,8 +6,16 @@ import platform
import pandas as pd
import torchmetrics
from functionalities_test import test_for_fixpoints
import numpy as np
import torch
from matplotlib import pyplot as plt
import seaborn as sns
from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm
if platform.node() == 'CarbonX':
debug = True
@@ -28,23 +36,12 @@ else:
DIR = None
pass
from network import MetaNet
WORKER = 10 if not debug else 2
BATCHSIZE = 500 if not debug else 50
EPOCH = 50 if not debug else 3
EPOCH = 100 if not debug else 3
VALIDATION_FRQ = 5 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -78,7 +75,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
if not final_model:
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
@@ -91,15 +88,16 @@ def validate(checkpoint_path, ratio=0.1):
# initialize metric
validmetric = torchmetrics.Accuracy()
ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
datas = MNIST(str(data_path), transform=utility_transforms, train=False)
datas = MNIST(str(data_path), transform=ut, train=False)
except RuntimeError:
datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
datas = MNIST(str(data_path), transform=ut, train=False, download=True)
valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
n_samples = int(len(d) * ratio)
n_samples = int(len(valid_d) * ratio)
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
@@ -119,6 +117,10 @@ def validate(checkpoint_path, ratio=0.1):
return acc
def new_train_storage_df():
return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
out_path = Path(out_path)
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
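Editor's note: validate() above leans on torchmetrics' accumulate-then-compute pattern, where each batch call updates running state and compute() aggregates at the end. A minimal self-contained sketch of that pattern (a toy model and random data stand in for the MetaNet checkpoint and MNIST):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import torchmetrics

# Toy stand-ins so the sketch runs on its own.
model = nn.Sequential(nn.Flatten(), nn.Linear(15 * 15, 10)).eval()
dataset = TensorDataset(torch.randn(100, 15, 15), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset, batch_size=25)

metric = torchmetrics.Accuracy()  # newer torchmetrics versions need task='multiclass', num_classes=10
with torch.no_grad():
    for batch_x, batch_y in loader:
        metric(model(batch_x), batch_y)  # each call accumulates batch statistics
acc = metric.compute()  # aggregate over all batches seen since the last reset
metric.reset()          # clear state before the next validation run
print(acc)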
@@ -130,18 +132,28 @@ def plot_training_result(path_to_dataframe):
# load from disk
df = pd.read_csv(path_to_dataframe, index_col=0)
# Set up figure
fig, ax1 = plt.subplots() # initializes figure and plots
ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.
# plots the first set of data, and sets it to ax1.
data = df[df['Metric'] == 'BatchLoss']
# plots the second set, and sets to ax2.
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
# plots the first set of data
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
palette = sns.color_palette()[0:data.reset_index()['Metric'].unique().shape[0]]
sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
palette=palette, ax=ax1)
# plots the second set of data
data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
palette = sns.color_palette()[len(palette):data.reset_index()['Metric'].unique().shape[0] + len(palette)]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ax=ax2)
ax1.set(yscale='log')
ax1.set(yscale='log', ylabel='Losses')
ax1.set_title('Training Lineplot')
ax2.set(ylabel='Accuracy')
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
ax1.get_legend().remove()
ax2.get_legend().remove()
plt.tight_layout()
if debug:
plt.show()
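Editor's note: the plotting change follows the standard matplotlib twin-axes pattern, with losses on a log-scaled left axis, accuracies on a linear right axis, and the per-axes legends removed in favour of one figure legend. A self-contained sketch with synthetic data:

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Synthetic stand-in for the train_store dataframe.
df = pd.DataFrame({'Epoch': list(range(10)) * 2,
                   'Metric': ['Task Loss'] * 10 + ['Test Accuracy'] * 10,
                   'Score': [2.0 ** -e for e in range(10)] + [e / 10 for e in range(10)]})

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
sns.lineplot(data=df[df['Metric'] == 'Task Loss'], x='Epoch', y='Score', hue='Metric', ax=ax1)
sns.lineplot(data=df[df['Metric'] == 'Test Accuracy'], x='Epoch', y='Score',
             marker='o', hue='Metric', ax=ax2)
ax1.set(yscale='log', ylabel='Losses')
ax2.set(ylabel='Accuracy')
fig.legend(loc='center right', title='Metric', bbox_to_anchor=(0.85, 0.5))
ax1.get_legend().remove()  # the figure legend replaces the per-axes ones
ax2.get_legend().remove()
plt.tight_layout()
plt.show()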
@@ -155,16 +167,17 @@ if __name__ == '__main__':
training = False
plotting = False
particle_analysis = True
as_sparse_network_test = True
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'mnist_test_half_size'
run_path = Path('output') / 'mnist_self_train_100_NEW_STYLE'
model_path = run_path / '0000_trained_model.zip'
df_store_path = run_path / 'train_store.csv'
if training:
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
dataset = MNIST(str(data_path), transform=utility_transforms)
except RuntimeError:
@@ -177,7 +190,7 @@ if __name__ == '__main__':
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
train_store = new_train_storage_df()
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
@@ -187,12 +200,9 @@ if __name__ == '__main__':
metric = None
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if self_train and is_self_train_epoch:
# Zero your gradients for every batch!
optimizer.zero_grad()
combined_self_train_loss = metanet.combined_self_train()
combined_self_train_loss.backward()
# Adjust learning weights
optimizer.step()
self_train_loss = metanet.combined_self_train(optimizer)
step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
train_store.loc[train_store.shape[0]] = step_log
# Zero your gradients for every batch!
optimizer.zero_grad()
@@ -206,7 +216,7 @@ if __name__ == '__main__':
optimizer.step()
step_log = dict(Epoch=epoch, Batch=batch,
Metric='BatchLoss', Score=loss.item())
Metric='Task Loss', Score=loss.item())
train_store.loc[train_store.shape[0]] = step_log
if is_validation_epoch:
metric(y.cpu(), batch_y.cpu())
@@ -223,23 +233,39 @@ if __name__ == '__main__':
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
train_store.loc[train_store.shape[0]] = validation_log
if particle_analysis:
counter_dict = defaultdict(lambda: 0)
# This returns the identity-function particles (discarded here)
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
for key, value in dict(counter_dict).items():
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
train_store = new_train_storage_df()
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(run_path / 'train_store.csv')
train_store.to_csv(df_store_path)
if plotting:
plot_training_result(run_path / 'train_store.csv')
plot_training_result(df_store_path)
if particle_analysis:
model_path = next(run_path.glob('*.tp'))
model_path = next(run_path.glob('*ckpt.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()
analysis_dict = defaultdict(dict)
counter_dict = defaultdict(lambda: 0)
for particle in latest_model.particles:
analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
test_for_fixpoints(counter_dict, latest_model.particles)
_ = test_for_fixpoints(counter_dict, list(latest_model.particles))
tqdm.write(str(dict(counter_dict)))
zero_ident = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('identity_func')
zero_other = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('other_func')
if as_sparse_network_test:
acc_pre = validate(model_path, ratio=1)
ident_ckpt = set_checkpoint(zero_ident, model_path.parent, -1, final_model=True)
ident_acc_post = validate(ident_ckpt, ratio=1)
tqdm.write(f'Zero_ident diff = {abs(ident_acc_post-acc_pre)}')
other_ckpt = set_checkpoint(zero_other, model_path.parent, -2, final_model=True)
other_acc_post = validate(other_ckpt, ratio=1)
tqdm.write(f'Zero_other diff = {abs(other_acc_post - acc_pre)}')
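Editor's note: worth flagging in the training loop above is the incremental metric flushing. After each validation epoch the dataframe is appended to the CSV (header written only on first creation) and replaced by a fresh empty frame, so a crash loses at most one epoch of logs. The pattern in isolation (the file name is a placeholder):

from pathlib import Path
import pandas as pd

def new_train_storage_df():
    return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])

store_path = Path('train_store.csv')  # placeholder path
store_path.unlink(missing_ok=True)    # start clean for the demo
train_store = new_train_storage_df()
for epoch in range(3):
    train_store.loc[train_store.shape[0]] = dict(Epoch=epoch, Batch=0,
                                                 Metric='Task Loss', Score=1.0 / (epoch + 1))
    # Append this epoch's rows; write the header only when the file is first created.
    train_store.to_csv(store_path, mode='a', header=not store_path.exists())
    train_store = new_train_storage_df()  # fresh buffer for the next epoch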

View File

@@ -1,16 +1,13 @@
import copy
from typing import Dict, List
import numpy as np
import torch
from tqdm import tqdm
from network import Net
def is_divergent(network: Net) -> bool:
for i in network.input_weight_matrix():
weight_value = i[0].item()
if np.isnan(weight_value).any() or np.isinf(weight_value).any():
return True
return False
return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()
def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
@@ -19,13 +16,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
target_data = network.create_target_weights(input_data)
predicted_values = network(input_data)
return np.allclose(target_data.detach().numpy(), predicted_values.detach().numpy(),
rtol=0, atol=epsilon)
return torch.allclose(target_data.detach(), predicted_values.detach(),
rtol=0, atol=epsilon)
def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
target_data = network.create_target_weights(network.input_weight_matrix().detach())
result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
# result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
return result
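Editor's note: the switch from np.allclose on detached numpy arrays to torch.allclose keeps these checks on tensors and avoids the numpy round-trip; with rtol=0 both reduce to a pure absolute-epsilon test. A quick self-contained illustration:

import torch

epsilon = pow(10, -5)
target = torch.zeros(5)
near = target + 5e-6  # within epsilon
far = target + 5e-4   # outside epsilon

# With rtol=0 this reduces to: all(|a - b| <= epsilon)
print(torch.allclose(near, target, rtol=0, atol=epsilon))  # True
print(torch.allclose(far, target, rtol=0, atol=epsilon))   # False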
@@ -49,15 +47,15 @@ def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
second_output = network(input_data_2)
# Perform the check: all(epsilon > abs(target_data - second_output))
check_abs_within_epsilon = np.allclose(target_data.detach().numpy(), second_output.detach().numpy(),
rtol=0, atol=epsilon)
check_abs_within_epsilon = torch.allclose(target_data.detach(), second_output.detach(),
rtol=0, atol=epsilon)
return check_abs_within_epsilon
def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
id_functions = id_functions or list()
for net in nets:
for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
if is_divergent(net):
fixpoint_counter["divergent"] += 1
net.is_fixpoint = "divergent"

View File

@@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim, Tensor
from tqdm import tqdm
def prng():
@@ -391,6 +392,17 @@ class MetaNet(nn.Module):
interface=self.width, width=self.out)
)
def replace_with_zero(self, ident_key):
replaced_particles = 0
for particle in self.particles:
if particle.is_fixpoint == ident_key:
particle.load_state_dict(
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
)
replaced_particles += 1
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
return self
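Editor's note: the zeroing trick in replace_with_zero, reduced to a standalone sketch on a plain Linear layer. load_state_dict accepts any mapping with matching keys and shapes, so feeding it zeros_like copies of the module's own state dict blanks the module in place:

import torch
from torch import nn

net = nn.Linear(3, 3)
net.load_state_dict({key: torch.zeros_like(state) for key, state in net.state_dict().items()})
print(all((p == 0).all().item() for p in net.parameters()))  # True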
def forward(self, x):
tensor = x
for meta_layer in self._meta_layer_list:
@@ -401,15 +413,22 @@ def particles(self):
def particles(self):
return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
def combined_self_train(self):
def combined_self_train(self, external_optimizer):
losses = []
for particle in self.particles:
# Zero your gradients for every batch!
external_optimizer.zero_grad()
# Integrate the optimizer step and backward pass here, per particle
input_data = particle.input_weight_matrix()
target_data = particle.create_target_weights(input_data)
output = particle(input_data)
losses.append(F.mse_loss(output, target_data))
return torch.hstack(losses).sum(dim=-1, keepdim=True)
loss = F.mse_loss(output, target_data)
losses.append(loss.detach())  # detach() must be called, otherwise bound methods are summed below
loss.backward()
# Adjust learning weights
external_optimizer.step()
# return torch.hstack(losses).sum(dim=-1, keepdim=True)
return sum(losses)
if __name__ == '__main__':
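Editor's note: the refactored combined_self_train runs one zero_grad/backward/step cycle per particle and returns only detached losses for logging. The same pattern on toy modules, as a self-contained sketch (identity targets stand in for the real weight-matrix targets):

import torch
from torch import nn
import torch.nn.functional as F

# Toy "particles": each one should learn to reproduce its input.
particles = [nn.Linear(4, 4) for _ in range(3)]
optimizer = torch.optim.SGD([p for net in particles for p in net.parameters()], lr=0.01)

def combined_self_train(external_optimizer):
    losses = []
    for particle in particles:
        external_optimizer.zero_grad()  # fresh gradients for every particle
        input_data = torch.randn(8, 4)
        target_data = input_data        # stand-in for create_target_weights()
        loss = F.mse_loss(particle(input_data), target_data)
        losses.append(loss.detach())    # keep a value for logging, not the graph
        loss.backward()
        external_optimizer.step()       # update immediately, per particle
    return sum(losses)

print(combined_self_train(optimizer).item())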