MetaNetworks Debugged II

This commit is contained in:
Steffen Illium 2022-02-01 18:17:11 +01:00
parent 246d825bb4
commit 1b7581e656
4 changed files with 105 additions and 61 deletions

View File

@@ -2,4 +2,5 @@ from .mixed_setting_exp import run_mixed_experiment
 from .robustness_exp import run_robustness_experiment
 from .self_application_exp import run_SA_experiment
 from .self_train_exp import run_ST_experiment
 from .soup_exp import run_soup_experiment
+import functionalities_test

View File

@@ -6,8 +6,16 @@ import platform
 import pandas as pd
 import torchmetrics
-from functionalities_test import test_for_fixpoints
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+import seaborn as sns
+from torch import nn
+from torch.nn import Flatten
+from torch.utils.data import Dataset, DataLoader
+from torchvision.datasets import MNIST
+from torchvision.transforms import ToTensor, Compose, Resize
+from tqdm import tqdm

 if platform.node() == 'CarbonX':
     debug = True
@@ -28,23 +36,12 @@ else:
     DIR = None
     pass

-import numpy as np
-import torch
-from matplotlib import pyplot as plt
-import seaborn as sns
-from torch import nn
-from torch.nn import Flatten
-from torch.utils.data import Dataset, DataLoader
-from torchvision.datasets import MNIST
-from torchvision.transforms import ToTensor, Compose, Resize
-from tqdm import tqdm
 from network import MetaNet
+from functionalities_test import test_for_fixpoints

 WORKER = 10 if not debug else 2
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 50 if not debug else 3
+EPOCH = 100 if not debug else 3
 VALIDATION_FRQ = 5 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -78,7 +75,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
     if not final_model:
         ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
     else:
-        ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
+        ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
     ckpt_path.parent.mkdir(exist_ok=True, parents=True)
     torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
@@ -91,15 +88,16 @@ def validate(checkpoint_path, ratio=0.1):
     # initialize metric
     validmetric = torchmetrics.Accuracy()
+    ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])

     try:
-        datas = MNIST(str(data_path), transform=utility_transforms, train=False)
+        datas = MNIST(str(data_path), transform=ut, train=False)
     except RuntimeError:
-        datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
+        datas = MNIST(str(data_path), transform=ut, train=False, download=True)
     valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

     model = torch.load(checkpoint_path, map_location=DEVICE).eval()
-    n_samples = int(len(d) * ratio)
+    n_samples = int(len(valid_d) * ratio)

     with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
         for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
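Beyond the transform rename and the `len(d)` to `len(valid_d)` fix, this function leans on torchmetrics' two-phase API: calling the metric object per batch both returns the batch accuracy and accumulates internal state, and a final `.compute()` yields the aggregate. A minimal sketch of that pattern with fabricated tensors, assuming a torchmetrics version where `Accuracy()` takes no arguments, as in this commit:

    import torch
    import torchmetrics

    metric = torchmetrics.Accuracy()
    for _ in range(5):
        preds = torch.rand(10, 3).softmax(dim=-1)  # fake per-class probabilities
        target = torch.randint(0, 3, (10,))        # fake labels
        batch_acc = metric(preds, target)          # updates state, returns batch accuracy
    total_acc = metric.compute()                   # accuracy over all batches seen
    metric.reset()                                 # clear state before the next run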
@@ -119,6 +117,10 @@ def validate(checkpoint_path, ratio=0.1):
     return acc


+def new_train_storage_df():
+    return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
+
+
 def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
     out_path = Path(out_path)
     ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
@@ -130,18 +132,28 @@ def plot_training_result(path_to_dataframe):
     # load from Drive
     df = pd.read_csv(path_to_dataframe, index_col=0)

+    # Set up figure
     fig, ax1 = plt.subplots()  # initializes figure and plots
     ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.

-    # plots the first set of data, and sets it to ax1.
-    data = df[df['Metric'] == 'BatchLoss']
-    # plots the second set, and sets to ax2.
-    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
+    # plots the first set of data
+    data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
+    palette = sns.color_palette()[0:data.reset_index()['Metric'].unique().shape[0]]
+    sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
+                 palette=palette, ax=ax1)

+    # plots the second set of data
     data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
+    palette = sns.color_palette()[len(palette):data.reset_index()['Metric'].unique().shape[0] + len(palette)]
+    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette)

-    ax1.set(yscale='log')
+    ax1.set(yscale='log', ylabel='Losses')
     ax1.set_title('Training Lineplot')
+    ax2.set(ylabel='Accuracy')
+    fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
+    ax1.get_legend().remove()
+    ax2.get_legend().remove()

     plt.tight_layout()
     if debug:
         plt.show()
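The legend handling here works around a seaborn behavior: each `lineplot` call with `hue` attaches its own legend to its target axes, so a `twinx` figure would otherwise show two overlapping legends. Collecting everything into one figure-level legend and removing the per-axes ones, as the new code does, looks like this in isolation (the DataFrame is invented for the example):

    import pandas as pd
    import seaborn as sns
    from matplotlib import pyplot as plt

    df = pd.DataFrame({'Epoch':  [0, 1, 2] * 2,
                       'Metric': ['Loss'] * 3 + ['Accuracy'] * 3,
                       'Score':  [0.9, 0.5, 0.2, 0.3, 0.6, 0.8]})

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
    sns.lineplot(data=df[df['Metric'] == 'Loss'], x='Epoch', y='Score', hue='Metric', ax=ax1)
    sns.lineplot(data=df[df['Metric'] == 'Accuracy'], x='Epoch', y='Score', hue='Metric', ax=ax2)

    fig.legend(loc='center right', title='Metric', bbox_to_anchor=(0.85, 0.5))
    ax1.get_legend().remove()  # drop the per-axes legends seaborn created
    ax2.get_legend().remove()
    plt.show()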
@@ -155,16 +167,17 @@ if __name__ == '__main__':
     training = False
     plotting = False
     particle_analysis = True
+    as_sparse_network_test = True

     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)

-    run_path = Path('output') / 'mnist_test_half_size'
+    run_path = Path('output') / 'mnist_self_train_100_NEW_STYLE'
     model_path = run_path / '0000_trained_model.zip'
+    df_store_path = run_path / 'train_store.csv'

     if training:
         utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
         try:
             dataset = MNIST(str(data_path), transform=utility_transforms)
         except RuntimeError:
@@ -177,7 +190,7 @@ if __name__ == '__main__':
         loss_fn = nn.CrossEntropyLoss()
         optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

-        train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
+        train_store = new_train_storage_df()
         for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
             is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
             is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
@@ -187,12 +200,9 @@ if __name__ == '__main__':
             metric = None
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                 if self_train and is_self_train_epoch:
-                    # Zero your gradients for every batch!
-                    optimizer.zero_grad()
-                    combined_self_train_loss = metanet.combined_self_train()
-                    combined_self_train_loss.backward()
-                    # Adjust learning weights
-                    optimizer.step()
+                    self_train_loss = metanet.combined_self_train(optimizer)
+                    step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
+                    train_store.loc[train_store.shape[0]] = step_log

                 # Zero your gradients for every batch!
                 optimizer.zero_grad()
@@ -206,7 +216,7 @@ if __name__ == '__main__':
                 optimizer.step()

                 step_log = dict(Epoch=epoch, Batch=batch,
-                                Metric='BatchLoss', Score=loss.item())
+                                Metric='Task Loss', Score=loss.item())
                 train_store.loc[train_store.shape[0]] = step_log
                 if is_validation_epoch:
                     metric(y.cpu(), batch_y.cpu())
@@ -223,23 +233,39 @@ if __name__ == '__main__':
                 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                       Metric='Test Accuracy', Score=accuracy.item())
                 train_store.loc[train_store.shape[0]] = validation_log
+            if particle_analysis:
+                counter_dict = defaultdict(lambda: 0)
+                # This returns ID-functions
+                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+                for key, value in dict(counter_dict).items():
+                    step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                    train_store.loc[train_store.shape[0]] = step_log
+            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
+            train_store = new_train_storage_df()

         accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
         validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                               Metric='Test Accuracy', Score=accuracy.item())
         train_store.loc[train_store.shape[0]] = validation_log
-        train_store.to_csv(run_path / 'train_store.csv')
+        train_store.to_csv(df_store_path)

     if plotting:
-        plot_training_result(run_path / 'train_store.csv')
+        plot_training_result(df_store_path)

     if particle_analysis:
-        model_path = next(run_path.glob('*.tp'))
+        model_path = next(run_path.glob('*ckpt.tp'))
         latest_model = torch.load(model_path, map_location=DEVICE).eval()
-        analysis_dict = defaultdict(dict)
         counter_dict = defaultdict(lambda: 0)
-        for particle in latest_model.particles:
-            analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
-        test_for_fixpoints(counter_dict, latest_model.particles)
+        _ = test_for_fixpoints(counter_dict, list(latest_model.particles))
+        tqdm.write(str(dict(counter_dict)))
+        zero_ident = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('identity_func')
+        zero_other = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('other_func')
+
+        if as_sparse_network_test:
+            acc_pre = validate(model_path, ratio=1)
+            ident_ckpt = set_checkpoint(zero_ident, model_path.parent, -1, final_model=True)
+            ident_acc_post = validate(ident_ckpt, ratio=1)
+            tqdm.write(f'Zero_ident diff = {abs(ident_acc_post - acc_pre)}')
+            other_ckpt = set_checkpoint(zero_other, model_path.parent, -2, final_model=True)
+            other_acc_post = validate(other_ckpt, ratio=1)
+            tqdm.write(f'Zero_other diff = {abs(other_acc_post - acc_pre)}')
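The new `as_sparse_network_test` branch is an ablation study: it zeroes every particle classified as an identity fixpoint (and, separately, every non-identity particle), re-validates each ablated copy on the full test set (`ratio=1`), and reports the accuracy delta against the untouched model. Condensed to its core, using only names defined in this script, the comparison amounts to this sketch:

    baseline = validate(model_path, ratio=1)  # accuracy of the unmodified model

    for tag, marker in [('identity_func', -1), ('other_func', -2)]:
        ablated = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(tag)
        ckpt = set_checkpoint(ablated, model_path.parent, marker, final_model=True)
        delta = abs(validate(ckpt, ratio=1) - baseline)
        tqdm.write(f'Accuracy change after zeroing {tag}: {delta}')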

View File

@@ -1,16 +1,13 @@
 import copy
 from typing import Dict, List

-import numpy as np
+import torch
+from tqdm import tqdm

 from network import Net


 def is_divergent(network: Net) -> bool:
-    for i in network.input_weight_matrix():
-        weight_value = i[0].item()
-        if np.isnan(weight_value).any() or np.isinf(weight_value).any():
-            return True
-    return False
+    return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()


 def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
@@ -19,13 +16,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(input_data)
     predicted_values = network(input_data)

-    return np.allclose(target_data.detach().numpy(), predicted_values.detach().numpy(),
-                       rtol=0, atol=epsilon)
+    return torch.allclose(target_data.detach(), predicted_values.detach(),
+                          rtol=0, atol=epsilon)


 def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
-    result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
+    result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result
@@ -49,15 +47,15 @@ def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
     second_output = network(input_data_2)

     # Perform the Check: all(epsilon > abs(input_data - second_output))
-    check_abs_within_epsilon = np.allclose(target_data.detach().numpy(), second_output.detach().numpy(),
-                                           rtol=0, atol=epsilon)
+    check_abs_within_epsilon = torch.allclose(target_data.detach(), second_output.detach(),
+                                              rtol=0, atol=epsilon)
     return check_abs_within_epsilon


 def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
     id_functions = id_functions or list()
-    for net in nets:
+    for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
         if is_divergent(net):
             fixpoint_counter["divergent"] += 1
             net.is_fixpoint = "divergent"
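`test_for_fixpoints` mutates the counter dict it is given and tags each net via `net.is_fixpoint`; the comment in the training script suggests it also returns the identity functions it found. The typical call pattern, as a sketch (the counter keys in the output are illustrative):

    from collections import defaultdict
    from functionalities_test import test_for_fixpoints

    counter = defaultdict(lambda: 0)
    # `particles` is any list of Net instances, e.g. list(metanet.particles)
    id_functions = test_for_fixpoints(counter, particles)
    print(dict(counter))  # e.g. {'identity_func': 42, 'other_func': 7, 'divergent': 1}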

View File

@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import optim, Tensor
+from tqdm import tqdm


 def prng():
@@ -391,6 +392,17 @@ class MetaNet(nn.Module):
                                     interface=self.width, width=self.out)
                           )

+    def replace_with_zero(self, ident_key):
+        replaced_particles = 0
+        for particle in self.particles:
+            if particle.is_fixpoint == ident_key:
+                particle.load_state_dict(
+                    {key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
+                )
+                replaced_particles += 1
+        tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
+        return self
+
     def forward(self, x):
         tensor = x
         for meta_layer in self._meta_layer_list:
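`replace_with_zero` relies on `load_state_dict` accepting any mapping whose keys and tensor shapes match the module's own, so a dict comprehension over `state_dict()` with `torch.zeros_like` silences a particle in place. The same idiom on a bare module, as a minimal self-contained sketch:

    import torch
    from torch import nn

    layer = nn.Linear(4, 4)
    # Load a zero-filled copy of every parameter and buffer back into the module.
    layer.load_state_dict({k: torch.zeros_like(v) for k, v in layer.state_dict().items()})
    assert all((p == 0).all() for p in layer.parameters())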
@@ -401,15 +413,22 @@ class MetaNet(nn.Module):
     def particles(self):
         return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)

-    def combined_self_train(self):
+    def combined_self_train(self, external_optimizer):
         losses = []
         for particle in self.particles:
+            # Zero your gradients for every batch!
+            external_optimizer.zero_grad()
             # Integrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data)
+            losses.append(loss.detach())  # keep only the detached value for logging
+            loss.backward()
+            # Adjust learning weights
+            external_optimizer.step()
+        # return torch.hstack(losses).sum(dim=-1, keepdim=True)
+        return sum(losses)


 if __name__ == '__main__':
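With the optimizer passed in, each particle now gets its own zero_grad/backward/step cycle on its self-replication loss, rather than one step on the summed loss as before; the detached per-particle losses are kept only so the caller can log their sum. A usage sketch under that contract (the MetaNet constructor arguments are illustrative, not taken from this diff):

    metanet = MetaNet(interface=15 * 15, depth=5, width=6, out=10)
    optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

    self_train_loss = metanet.combined_self_train(optimizer)  # one optimizer step per particle
    print(self_train_loss.item())  # detached scalar sum, safe for logging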