sparse network redo

Author: Steffen Illium
Date: 2022-02-20 21:21:22 +01:00
parent 52081d176e
commit f25cee5203
7 changed files with 365 additions and 270 deletions

@@ -1,9 +1,17 @@
 # Bureaucratic Cohort Swarms
-### (The Meta-Task Experience) # Deadline: 28.02.22
-## Experimente
+### Pruning Networks by SRNN
+###### Deadline: 28.02.22
 Data Exchange: [Google Drive Folder](***REMOVED***)
+Paper Template: [Overleaf Project](***REMOVED***)
+## Experimente
 ### Fixpoint Tests:
 - [X] Dropout Test

@@ -1,40 +0,0 @@
-import numpy as np
-import torch
-import pandas as pd
-import re
-from pathlib import Path
-import seaborn as sns
-from matplotlib import pyplot as plt
-
-from network import FixTypes
-
-if __name__ == '__main__':
-    p = Path(r'experiments\output\mn_st_200_4_alpha_100\trained_model_ckpt_e200.tp')
-    m = torch.load(p, map_location=torch.device('cpu'))
-    particles = [y for x in m._meta_layer_list for y in x.particles]
-    df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name', 'color'])
-    colors = []
-    for particle in particles:
-        l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", particle.name).split('_')]
-        color = sns.color_palette()[0 if particle.is_fixpoint == FixTypes.identity_func else 1]
-        # color = 'orange' if particle.is_fixpoint == FixTypes.identity_func else 'blue'
-        colors.append(color)
-        df.loc[df.shape[0]] = (particle.is_fixpoint, l-1, w, particle.name, color)
-        df.loc[df.shape[0]] = (particle.is_fixpoint, l, c, particle.name, color)
-    for layer in list(df['layer'].unique()):
-        divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
-        df.loc[(df['layer'] == layer), 'neuron'] /= divisor
-    print('gathered')
-    for n, (fixtype, color) in enumerate(zip([FixTypes.other_func, FixTypes.identity_func], ['blue', 'orange'])):
-        plt.clf()
-        ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
-                          legend=False, estimator=None,
-                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
-        # ax.set(yscale='log', ylabel='Neuron')
-        ax.set_title(fixtype)
-        plt.show()
-        print('plottet')

@@ -1,4 +1,5 @@
 import pickle
+import re
 from collections import defaultdict
 from pathlib import Path
 import sys
@@ -17,7 +18,7 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 from tqdm import tqdm
 
+# noinspection DuplicatedCode
 if platform.node() == 'CarbonX':
     debug = True
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
@@ -37,14 +38,15 @@ else:
     DIR = None
     pass
 
-from network import MetaNet, FixTypes
+from network import MetaNet, FixTypes as ft
+from sparse_net import SparseNetwork
 from functionalities_test import test_for_fixpoints
 
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
 EPOCH = 200
-VALIDATION_FRQ = 5 if not debug else 1
+VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -139,7 +141,7 @@ def plot_training_particle_types(path_to_dataframe):
     df = pd.read_csv(path_to_dataframe, index_col=False)
     # Set up figure
     fig, ax = plt.subplots()  # initializes figure and plots
-    data = df[df['Metric'].isin(FixTypes.all_types())]
+    data = df.loc[df['Metric'].isin(ft.all_types())]
     fix_types = data['Metric'].unique()
     data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
     _ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types], labels=fix_types.tolist())
@@ -189,196 +191,253 @@ def plot_training_result(path_to_dataframe):
     plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
 
 
+def plot_network_connectivity_by_fixtype(path_to_trained_model):
+    m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
+    # noinspection PyProtectedMember
+    particles = [y for x in m._meta_layer_list for y in x.particles]
+    df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
+    for prtcl in particles:
+        l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
+        df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
+        df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
+    for layer in list(df['layer'].unique()):
+        # Rescale
+        divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
+        df.loc[(df['layer'] == layer), 'neuron'] /= divisor
+    print('gathered')
+    for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
+        plt.clf()
+        ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
+                          legend=False, estimator=None,
+                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
+        ax.set_title(fixtype)
+        plt.show()
+        print('plottet')
+
+
+def run_particle_dropout_test(run_path):
+    diff_store_path = run_path / 'diff_store.csv'
+    prtcl_dict = defaultdict(lambda: 0)
+    _ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
+    tqdm.write(str(dict(prtcl_dict)))
+    acc_pre = validate(model_path, ratio=1).item()
+    diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
+    for fixpoint_type in ft.all_types():
+        new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
+        if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
+            new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
+            acc_post = validate(new_ckpt, ratio=1).item()
+            acc_diff = abs(acc_post - acc_pre)
+            tqdm.write(f'Zero_ident diff = {acc_diff}')
+            diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
+
+    diff_df.to_csv(diff_store_path, mode='a', header=not df_store_path.exists(), index=False)
+    return diff_store_path
+
+
+def plot_dropout_stacked_barplot(path_to_diff_df):
+    diff_df = pd.read_csv(path_to_diff_df)
+    particle_dict = defaultdict(lambda: 0)
+    _ = test_for_fixpoints(particle_dict, list(latest_model.particles))
+    tqdm.write(str(dict(particle_dict)))
+    plt.clf()
+    fig, ax = plt.subplots(ncols=2)
+    colors = sns.color_palette()[:diff_df.shape[0]]
+    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
+    # noinspection PyUnboundLocalVariable
+    for idx, patch in enumerate(barplot.patches):
+        if idx != 0:
+            # we recenter the bar
+            patch.set_x(patch.get_x() + idx * 0.035)
+    ax[0].set_title('Accuracy after particle dropout')
+    ax[0].set_xlabel('Accuracy')
+    ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
+    ax[1].set_title('Particle Count for ')
+    plt.tight_layout()
+    if debug:
+        plt.show()
+    else:
+        plt.savefig(Path(path_to_diff_df.parent / 'dropout_stacked_barplot.png'), dpi=300)
+
+
+def run_particle_dropout_and_plot(run_path):
+    diff_store_path = run_particle_dropout_test(run_path)
+    plot_dropout_stacked_barplot(diff_store_path)
+
+
 def flat_for_store(parameters):
     return (x.item() for y in parameters for x in y.detach().flatten())
 
 
 if __name__ == '__main__':
+    use_sparse_implementation = True
     self_train = True
-    training = False
-    plotting = True
-    particle_analysis = True
-    as_sparse_network_test = True
+    training = True
     train_to_id_first = False
-    self_train_alpha = 100
+    train_to_task_first = False
+    train_to_task_first_sequential = True
+    tsk_threshold = 0.855
+    self_train_alpha = 1
     batch_train_beta = 1
-    weight_hidden_size = 4
+    weight_hidden_size = 3
     residual_skip = True
-    dropout = 0
+    n_seeds = 2
 
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
+    assert not (train_to_task_first and train_to_id_first)
 
     st_str = f'{"" if self_train else "no_"}st'
     a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
-    res_str = f'{"" if residual_skip else "_no"}_res'
-    dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
+    res_str = f'{"" if residual_skip else "_no_res"}'
+    # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
     id_str = f'{f"_StToId" if train_to_id_first else ""}'
-    run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{dr_str}{id_str}'
-
-    model_path = run_path / '0000_trained_model.zip'
-    df_store_path = run_path / 'train_store.csv'
-    weight_store_path = run_path / 'weight_store.csv'
-    srnn_parameters = dict()
-
-    if training:
-        utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
-        try:
-            dataset = MNIST(str(data_path), transform=utility_transforms)
-        except RuntimeError:
-            dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
-        d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
-
-        interface = np.prod(dataset[0][0].shape)
-        metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip, dropout=dropout,
-                          weight_hidden_size=weight_hidden_size,
-                          ).to(DEVICE)
-        meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
-
-        loss_fn = nn.CrossEntropyLoss()
-        optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
-
-        train_store = new_storage_df('train', None)
-        weight_store = new_storage_df('weights', meta_weight_count)
-        for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
-            is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-            is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
-            metanet = metanet.train()
-            if is_validation_epoch:
-                metric = torchmetrics.Accuracy()
-            else:
-                metric = None
-            init_st = train_to_id_first and all(x.is_fixpoint == FixTypes.identity_func for x in metanet.particles)
-            for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
-                if (self_train and is_self_train_epoch) or init_st:
-                    # Zero your gradients for every batch!
-                    optimizer.zero_grad()
-                    self_train_loss = metanet.combined_self_train() * self_train_alpha
-                    self_train_loss.backward()
-                    # Adjust learning weights
-                    optimizer.step()
-                    step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
-                    train_store.loc[train_store.shape[0]] = step_log
-                if train_to_id_first <= epoch:
-                    # Zero your gradients for every batch!
-                    optimizer.zero_grad()
-                    batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                    y = metanet(batch_x)
-                    # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                    loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
-                    loss.backward()
-
-                    # Adjust learning weights
-                    optimizer.step()
-
-                    step_log = dict(Epoch=epoch, Batch=batch,
-                                    Metric='Task Loss', Score=loss.item())
-                    train_store.loc[train_store.shape[0]] = step_log
-                    if is_validation_epoch:
-                        metric(y.cpu(), batch_y.cpu())
-                if batch >= 3 and debug:
-                    break
-            if is_validation_epoch:
-                metanet = metanet.eval()
-                if train_to_id_first <= epoch:
-                    validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                          Metric='Train Accuracy', Score=metric.compute().item())
-                    train_store.loc[train_store.shape[0]] = validation_log
-                accuracy = checkpoint_and_validate(metanet, run_path, epoch)
-                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                      Metric='Test Accuracy', Score=accuracy.item())
-                train_store.loc[train_store.shape[0]] = validation_log
-            if particle_analysis and (init_st or is_validation_epoch):
-                counter_dict = defaultdict(lambda: 0)
-                # This returns ID-functions
-                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
-                for key, value in dict(counter_dict).items():
-                    step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
-                    train_store.loc[train_store.shape[0]] = step_log
-            if init_st or is_validation_epoch:
-                for particle in metanet.particles:
-                    weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
-                    weight_store.loc[weight_store.shape[0]] = weight_log
-            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-            weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-            train_store = new_storage_df('train', None)
-            weight_store = new_storage_df('weights', meta_weight_count)
-
-        metanet.eval()
-        if particle_analysis:
-            counter_dict = defaultdict(lambda: 0)
-            # This returns ID-functions
-            _ = test_for_fixpoints(counter_dict, list(metanet.particles))
-            for key, value in dict(counter_dict).items():
-                step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
-                train_store.loc[train_store.shape[0]] = step_log
-        accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
-        validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
-                              Metric='Test Accuracy', Score=accuracy.item())
-        for particle in metanet.particles:
-            weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
-            weight_store.loc[weight_store.shape[0]] = weight_log
-        train_store.loc[train_store.shape[0]] = validation_log
-        train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-        weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-
-    if plotting:
-        plot_training_result(df_store_path)
-        if particle_analysis:
-            plot_training_particle_types(df_store_path)
-
-    if particle_analysis:
-        try:
-            model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
-        except StopIteration:
-            print('Model pattern did not trigger.')
-            print(f'Search path was: {run_path}:')
-            print(f'Found Models are: {list(run_path.rglob(".tp"))}')
-            exit(1)
-        latest_model = torch.load(model_path, map_location=DEVICE).eval()
-        counter_dict = defaultdict(lambda: 0)
-        _ = test_for_fixpoints(counter_dict, list(latest_model.particles))
-        tqdm.write(str(dict(counter_dict)))
-
-    if as_sparse_network_test:
-        acc_pre = validate(model_path, ratio=1).item()
-        diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
-        for fixpoint_type in FixTypes.all_types():
-            new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
-            if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
-                new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
-                acc_post = validate(new_ckpt, ratio=1).item()
-                acc_diff = abs(acc_post-acc_pre)
-                tqdm.write(f'Zero_ident diff = {acc_diff}')
-                diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
-
-        if plotting:
-            plt.clf()
-            fig, ax = plt.subplots(ncols=2)
-            labels = ['Full Network', 'Sparse, No Identity', 'Sparse, No Other']
-            colors = sns.color_palette()[:diff_df.shape[0]] if diff_df.shape[0] >= 2 else sns.color_palette()[0]
-            barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
-            # noinspection PyUnboundLocalVariable
-            for idx, patch in enumerate(barplot.patches):
-                if idx != 0:
-                    # we recenter the bar
-                    patch.set_x(patch.get_x() + idx * 0.035)
-            ax[0].set_title('Accuracy after particle dropout')
-            ax[0].set_xlabel('Accuracy')
-            # ax[0].legend()
-            ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=colors, )
-            ax[1].set_title('Particle Count for ')
-            # ax[1].set_xlabel('')
-            plt.tight_layout()
-            if debug:
-                plt.show()
-            else:
-                plt.savefig(Path(run_path / 'dropout_stacked_barplot.png'), dpi=300)
+    tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first else ""}'
+    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}'
+
+    if use_sparse_implementation:
+        metanet_class = SparseNetwork
+    else:
+        metanet_class = MetaNet
+
+    for seed in range(n_seeds):
+        seed_path = exp_path / str(seed)
+
+        model_path = seed_path / '0000_trained_model.zip'
+        df_store_path = seed_path / 'train_store.csv'
+        weight_store_path = seed_path / 'weight_store.csv'
+        srnn_parameters = dict()
+
+        if training:
+            utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
+            try:
+                dataset = MNIST(str(data_path), transform=utility_transforms)
+            except RuntimeError:
+                dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
+            d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+
+            interface = np.prod(dataset[0][0].shape)
+            metanet = metanet_class(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                    weight_hidden_size=weight_hidden_size,).to(DEVICE)
+            meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
+
+            loss_fn = nn.CrossEntropyLoss()
+            optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
+
+            train_store = new_storage_df('train', None)
+            weight_store = new_storage_df('weights', meta_weight_count)
+            init_tsk = train_to_task_first
+            for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
+                is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
+                is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
+                metanet = metanet.train()
+                if is_validation_epoch:
+                    metric = torchmetrics.Accuracy()
+                else:
+                    metric = None
+                init_st = train_to_id_first and not all(x.is_fixpoint == ft.identity_func for x in metanet.particles)
+                for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
+                    if self_train and not init_tsk and (is_self_train_epoch or init_st):
+                        # Zero your gradients for every batch!
+                        optimizer.zero_grad()
+                        self_train_loss = metanet.combined_self_train() * self_train_alpha
+                        self_train_loss.backward()
+                        # Adjust learning weights
+                        optimizer.step()
+                        step_log = dict(Epoch=epoch, Batch=batch,
+                                        Metric='Self Train Loss', Score=self_train_loss.item())
+                        train_store.loc[train_store.shape[0]] = step_log
+                    if not init_st:
+                        # Zero your gradients for every batch!
+                        optimizer.zero_grad()
+                        batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                        y_pred = metanet(batch_x)
+                        # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+                        loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
+                        loss.backward()
+
+                        # Adjust learning weights
+                        optimizer.step()
+
+                        step_log = dict(Epoch=epoch, Batch=batch,
+                                        Metric='Task Loss', Score=loss.item())
+                        train_store.loc[train_store.shape[0]] = step_log
+                        if is_validation_epoch:
+                            metric(y_pred.cpu(), batch_y.cpu())
+                    if batch >= 3 and debug:
+                        break
+                if is_validation_epoch:
+                    metanet = metanet.eval()
+                    if train_to_id_first <= epoch:
+                        validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                              Metric='Train Accuracy', Score=metric.compute().item())
+                        train_store.loc[train_store.shape[0]] = validation_log
+                    accuracy = checkpoint_and_validate(metanet, seed_path, epoch).item()
+                    validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                          Metric='Test Accuracy', Score=accuracy)
+                    train_store.loc[train_store.shape[0]] = validation_log
+                    if init_tsk or (train_to_task_first and train_to_task_first_sequential):
+                        init_tsk = accuracy <= tsk_threshold
+                if init_st or is_validation_epoch:
+                    counter_dict = defaultdict(lambda: 0)
+                    # This returns ID-functions
+                    _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+                    for key, value in dict(counter_dict).items():
+                        step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                        train_store.loc[train_store.shape[0]] = step_log
+                if init_st or is_validation_epoch:
+                    for particle in metanet.particles:
+                        weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
+                        weight_store.loc[weight_store.shape[0]] = weight_log
+                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+                train_store = new_storage_df('train', None)
+                weight_store = new_storage_df('weights', meta_weight_count)
+
+            metanet.eval()
+            counter_dict = defaultdict(lambda: 0)
+            # This returns ID-functions
+            _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+            for key, value in dict(counter_dict).items():
+                step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
+                train_store.loc[train_store.shape[0]] = step_log
+            accuracy = checkpoint_and_validate(metanet, seed_path, EPOCH, final_model=True)
+            validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
+                                  Metric='Test Accuracy', Score=accuracy.item())
+            for particle in metanet.particles:
+                weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
+                weight_store.loc[weight_store.shape[0]] = weight_log
+            train_store.loc[train_store.shape[0]] = validation_log
+            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+            weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+
+        plot_training_result(df_store_path)
+        plot_training_particle_types(df_store_path)
+
+        try:
+            model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+        except StopIteration:
+            print('Model pattern did not trigger.')
+            print(f'Search path was: {seed_path}:')
+            print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
+            exit(1)
+        latest_model = torch.load(model_path, map_location=DEVICE).eval()
+
+        run_particle_dropout_and_plot(seed_path)
+        plot_network_connectivity_by_fixtype(model_path)
+
+    if n_seeds >= 2:
+        pass
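The new `run_particle_dropout_test` above loads the trained checkpoint once per fixpoint type, zeroes out every particle of that type, revalidates, and records the accuracy drop relative to the intact model. A minimal, self-contained sketch of the same ablation idea on a toy classifier follows; the model, data, and helper names are made up for illustration and are not the repo's `validate`/`replace_with_zero` API:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in model and random "evaluation" data.
model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 10))
xs, ys = torch.randn(256, 20), torch.randint(0, 10, (256,))

def accuracy(net):
    with torch.no_grad():
        return (net(xs).argmax(dim=1) == ys).float().mean().item()

acc_pre = accuracy(model)
results = []
for name, _ in model.named_parameters():        # each parameter group plays the role of a "particle type"
    ablated = copy.deepcopy(model)               # ablate a copy, keep the original intact
    with torch.no_grad():
        dict(ablated.named_parameters())[name].zero_()
    acc_post = accuracy(ablated)
    results.append((name, acc_post, abs(acc_post - acc_pre)))

for name, acc_post, diff in results:
    print(f'{name:12s} acc={acc_post:.3f} diff={diff:.3f}')
```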

@@ -68,7 +68,6 @@ class Net(nn.Module):
             for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
                 self.state_dict()[layer_name][line_id][weight_id] = new_weights[i]
                 i += 1
-
         return self

     def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
@@ -100,7 +99,6 @@ class Net(nn.Module):
         self._weight_pos_enc_and_mask = None
-
     @property
     def _weight_pos_enc(self):
         if self._weight_pos_enc_and_mask is None:
@@ -127,8 +125,8 @@ class Net(nn.Module):
             # Normalize 1,2,3 column of dim 1
             last_pos_idx = self.input_size - 4
-            norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
-            weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
+            max_per_col, _ = weight_matrix[:, 1:-last_pos_idx].max(keepdim=True, dim=0)
+            weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / max_per_col) + 1e-8

             # computations
             # create a mask where pos is 0 if it is to be replaced
@@ -389,6 +387,7 @@ class MetaNet(nn.Module):
     def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1,):
         super().__init__()
+        self.residual_skip = residual_skip
         self.dropout = dropout
         self.activation = activation
         self.out = out
@@ -398,7 +397,6 @@ class MetaNet(nn.Module):
         self.weight_interface = weight_interface
         self.weight_hidden_size = weight_hidden_size
         self.weight_output_size = weight_output_size
-
         self._meta_layer_first = MetaLayer(name=f'L{0}',
                                            interface=self.interface,
                                            width=self.width,
@@ -411,6 +409,7 @@ class MetaNet(nn.Module):
                                                       weight_interface=weight_interface,
                                                       weight_hidden_size=weight_hidden_size,
                                                       weight_output_size=weight_output_size,
                                                       ) for layer_idx in range(self.depth - 2)]
                                                      )
         self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
@@ -441,10 +440,10 @@ class MetaNet(nn.Module):
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
             if self.dropout != 0:
                 tensor = self.dropout_layer(tensor)
-            if idx % 2 == 1:
+            if idx % 2 == 1 and self.residual_skip:
                 x = tensor.clone()
             tensor = meta_layer(tensor)
-            if idx % 2 == 0:
+            if idx % 2 == 0 and self.residual_skip:
                 tensor = tensor + x
             if self.dropout != 0:
                 x = self.dropout_layer(x)
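The last hunk gates MetaNet's skip connections on the new `self.residual_skip` flag: odd-indexed layers stash their input, the following even-indexed layer adds it back. A small, self-contained stand-in that shows the same alternating-skip pattern with plain `nn.Linear` layers (not the repo's `MetaLayer`):

```python
import torch
from torch import nn

class TinyResidualStack(nn.Module):
    """Illustrative only: skip every second layer, switchable with one flag."""
    def __init__(self, width=6, depth=4, residual_skip=True):
        super().__init__()
        self.residual_skip = residual_skip
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])

    def forward(self, tensor):
        x = None
        for idx, layer in enumerate(self.layers, start=1):
            if idx % 2 == 1 and self.residual_skip:
                x = tensor.clone()           # remember the input of the odd layer
            tensor = layer(tensor)
            if idx % 2 == 0 and self.residual_skip:
                tensor = tensor + x          # add it back after the even layer
        return tensor

batch = torch.randn(8, 6)
print(TinyResidualStack(residual_skip=True)(batch).shape)   # torch.Size([8, 6])
print(TinyResidualStack(residual_skip=False)(batch).shape)  # same shape, no skips
```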

@@ -56,7 +56,7 @@ if __name__ == '__main__':
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()

-    model = torch.load("trained_model_ckpt_e200.tp", map_location=DEVICE).eval()
+    model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(weights, d_test)

@@ -1,85 +1,114 @@
+from torch import nn
+
 from network import Net
+from typing import List
 from functionalities_test import is_identity_function
 from tqdm import tqdm,trange
 import numpy as np
 from pathlib import Path
 import torch
 from torch.nn import Flatten
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 
 
-class SparseLayer():
+class SparseLayer(nn.Module):
     def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
+        super(SparseLayer, self).__init__()
         self.nr_nets = nr_nets
         self.interface_dim = interface
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
         self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
-        self.sparse_sub_layer = []
-        self.weights = []
-        for layer_id in range(depth):
-            layer, weights = self.coo_sparse_layer(layer_id)
-            self.sparse_sub_layer.append(layer)
+        self.sparse_sub_layer = list()
+        self.indices = list()
+        self.diag_shapes = list()
+        self.weights = nn.ParameterList()
+        self._particles = None
+        for layer_id in range(self.depth_dim):
+            indices, weights, diag_shape = self.coo_sparse_layer(layer_id)
+            self.indices.append(indices)
+            self.diag_shapes.append(diag_shape)
             self.weights.append(weights)
 
     def coo_sparse_layer(self, layer_id):
         layer_shape = list(self.dummy_net.parameters())[layer_id].shape
-        #print(layer_shape)
+        # (out_cells, in_cells) -> (2,5), (2,2), (1,2)
         sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
-        indices = np.argwhere(sparse_diagonal == 1).T
-        values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))
-        #values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))
-        s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)
-        print(f"L{layer_id}:", s.shape)
-        return s, values
+        indices = torch.Tensor(np.argwhere(sparse_diagonal == 1).T)
+        values = torch.nn.Parameter(
+            torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]))), requires_grad=True
+        )
+        return indices, values, sparse_diagonal.shape
 
     def get_self_train_inputs_and_targets(self):
         encoding_matrix, mask = self.dummy_net._weight_pos_enc
         # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN
-        # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net
-        weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights]  #[nr_layers*[nr_net*nr_weights_layer_i]]
-        weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)]  #[nr_net*[nr_weights]]
-        inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)])  #(16, 25)
+        # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2
+        # and first hidden*out weights of layer3 = first net
+        # [nr_layers*[nr_net*nr_weights_layer_i]]
+        weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights]
+        # [nr_net*[nr_weights]]
+        weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]
+        # (16, 25)
+        inputs = torch.hstack(
+            [encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
+             for i in range(self.nr_nets)]
+        )
         targets = torch.hstack(weights_per_net)
-        return inputs.T, targets.T
+        return inputs.T.detach(), targets.T.detach()
+
+    @property
+    def particles(self):
+        if self._particles is None:
+            self._particles = [Net(self.interface_dim, self.hidden_dim, self.out_dim) for _ in range(self.nr_nets)]
+            pass
+        else:
+            pass
+        # Particle Weight Update
+        all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
+        weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
+                           range(self.nr_nets)]  # [nr_net*[nr_weights]]
+        for particles, weights in zip(self._particles, weights_per_net):
+            particles.apply_weights(weights)
+        return self._particles
 
     def __call__(self, x):
-        X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)
-        #print("X1", X1.shape)
-
-        X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)
-        #print("X2", X2.shape)
-
-        X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)
-        #print("X3", X3.shape)
-
-        return X3
+        for indices, diag_shapes, weights in zip(self.indices, self.diag_shapes, self.weights):
+            s = torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
+            x = torch.sparse.mm(s, x)
+        return x
+
+    def to(self, *args, **kwargs):
+        super(SparseLayer, self).to(*args, **kwargs)
+        self.sparse_sub_layer = [sparse_sub_layer.to(*args, **kwargs) for sparse_sub_layer in self.sparse_sub_layer]
+        return self
 
 
 def test_sparse_layer():
     net = SparseLayer(500)  #50 parallel nets
     loss_fn = torch.nn.MSELoss(reduction="sum")
-    optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)
-    #optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
+    optimizer = torch.optim.SGD(net.weights, lr=0.004, momentum=0.9)
+    # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
 
     for train_iteration in trange(1000):
         optimizer.zero_grad()
         X,Y = net.get_self_train_inputs_and_targets()
         out = net(X)
 
         loss = loss_fn(out, Y)
 
         # print("X:", X.shape, "Y:", Y.shape)
         # print("OUT", out.shape)
         # print("LOSS", loss.item())
 
         loss.backward(retain_graph=True)
 
         optimizer.step()
@@ -88,54 +117,95 @@ def test_sparse_layer():
     print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
 
 
 def embed_batch(x, repeat_dim):
     # x of shape (batchsize, flat_img_dim)
     x = x.unsqueeze(-1)  #(batchsize, flat_img_dim, 1)
-    return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim)  #(batchsize, flat_img_dim, encoding_dim*repeat_dim)
+    return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim)  #(batchsize, flat_img_dim, encoding_dim*repeat_dim)
 
 
 def embed_vector(x, repeat_dim):
     # x of shape [flat_img_dim]
-    x = x.unsqueeze(-1) #(flat_img_dim, 1)
-    return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim)  #(flat_img_dim, encoding_dim*repeat_dim)
+    x = x.unsqueeze(-1)  # (flat_img_dim, 1)
+    # (flat_img_dim, encoding_dim*repeat_dim)
+    return torch.cat((torch.zeros(x.shape[0], 4), x), dim=1).repeat(1,repeat_dim)
 
 
-class SparseNetwork():
-    def __init__(self, input_dim, depth, width, out):
+class SparseNetwork(nn.Module):
+    def __init__(self, input_dim, depth, width, out, residual_skip=True,
+                 weight_interface=5, weight_hidden_size=2, weight_output_size=1
+                 ):
+        super(SparseNetwork, self).__init__()
+        self.residual_skip = residual_skip
         self.input_dim = input_dim
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
-        self.sparse_layers = []
-        self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))
-        self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])
-        self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))
+        self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
+                                       interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
+        self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
+                                      interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
+        self.hidden_layers = nn.ModuleList([
+            SparseLayer(self.hidden_dim * self.hidden_dim,
+                        interface=weight_interface, width=weight_hidden_size, out=weight_output_size
+                        ) for _ in range(self.depth_dim - 2)])
 
     def __call__(self, x):
-        for sparse_layer in self.sparse_layers[:-1]:
-            # batch pass (one by one, sparse bmm doesn't support grad)
-            if len(x.shape) > 1:
-                embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
-                x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt])  #[batchsize, hidden*inpt_dim, feature_dim]
-            # vector
-            else:
-                embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
-                x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)
-            print("out", x.shape)
-
-        # output layer
-        sparse_layer = self.sparse_layers[-1]
-        if len(x.shape) > 1:
-            embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
-            x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt])  #[batchsize, hidden*inpt_dim, feature_dim]
-        else:
-            embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
-            x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)
-        print("out", x.shape)
-        return x
+        tensor = self.sparse_layer_forward(x, self.first_layer)
+        for nl_idx, network_layer in enumerate(self.hidden_layers):
+            if nl_idx % 2 == 0 and self.residual_skip:
+                residual = tensor.clone()
+            # Sparse Layer pass
+            tensor = self.sparse_layer_forward(tensor, network_layer)
+
+            if nl_idx % 2 != 0 and self.residual_skip:
+                # noinspection PyUnboundLocalVariable
+                tensor += residual
+        tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
+        return tensor
+
+    def sparse_layer_forward(self, x, sparse_layer, view_dim=None):
+        view_dim = view_dim if view_dim else self.hidden_dim
+        # batch pass (one by one, sparse bmm doesn't support grad)
+        if len(x.shape) > 1:
+            embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
+            # [batchsize, hidden*inpt_dim, feature_dim]
+            x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1) for inpt in
+                             embedded_inpt])
+        # vector
+        else:
+            embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
+            x = sparse_layer(embedded_inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1)
+        return x
+
+    @property
+    def particles(self):
+        particles = []
+        particles.extend(self.first_layer.particles)
+        for layer in self.hidden_layers:
+            particles.extend(layer.particles)
+        particles.extend(self.last_layer.particles)
+        return iter(particles)
+
+    def to(self, *args, **kwargs):
+        super(SparseNetwork, self).to(*args, **kwargs)
+        self.first_layer = self.first_layer.to(*args, **kwargs)
+        self.last_layer = self.last_layer.to(*args, **kwargs)
+        self.hidden_layers = nn.ModuleList([hidden_layer.to(*args, **kwargs) for hidden_layer in self.hidden_layers])
+        return self
+
+    def combined_self_train(self):
+        import time
+        t = time.time()
+        losses = []
+        for layer in [self.first_layer, *self.hidden_layers, self.last_layer]:
+            x, target_data = layer.get_self_train_inputs_and_targets()
+            output = layer(x)
+            losses.append(F.mse_loss(output, target_data))
+        print('Time Taken:', time.time() - t)
+        return torch.hstack(losses).sum(dim=-1, keepdim=True)
 
 
 def test_sparse_net():
     utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
@@ -150,7 +220,6 @@ def test_sparse_net():
     data_dim = np.prod(dataset[0][0].shape)
     metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
     batchx, batchy = next(iter(d))
-    batchx.shape, batchy.shape
     metanet(batchx)
@@ -176,6 +245,6 @@ def test_manual_for_loop():
 
 if __name__ == '__main__':
     test_sparse_layer()
-    test_sparse_net()
-    #for comparison
+    # test_sparse_net()
+    # for comparison
     test_manual_for_loop()
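The core of `SparseLayer.coo_sparse_layer` is that the weights of `nr_nets` independent tiny networks are packed into one block-diagonal sparse COO matrix, so a single `torch.sparse.mm` evaluates all of them in parallel instead of looping over them. A self-contained sketch of that trick with made-up sizes, verified against per-net dense matmuls:

```python
import numpy as np
import torch

# Illustrative sizes: 4 tiny nets, each with a (2 x 5) first-layer weight matrix.
n_nets, n_out, n_in = 4, 2, 5
dense_weights = [torch.randn(n_out, n_in) for _ in range(n_nets)]

# Block-diagonal pattern: net i occupies rows [i*n_out, (i+1)*n_out) and cols [i*n_in, (i+1)*n_in).
pattern = np.eye(n_nets).repeat(n_out, axis=-2).repeat(n_in, axis=-1)
indices = torch.tensor(np.argwhere(pattern == 1).T, dtype=torch.long)
values = torch.cat([w.flatten() for w in dense_weights])        # row-major, matches argwhere order
sparse_block_diag = torch.sparse_coo_tensor(indices, values, pattern.shape)

# Stack each net's input along the shared dimension.
inputs = [torch.randn(n_in, 1) for _ in range(n_nets)]
stacked = torch.cat(inputs, dim=0)                              # (n_nets * n_in, 1)

# One sparse matmul == n_nets independent small matmuls.
combined = torch.sparse.mm(sparse_block_diag, stacked)          # (n_nets * n_out, 1)
separate = torch.cat([w @ x for w, x in zip(dense_weights, inputs)], dim=0)
print(torch.allclose(combined, separate, atol=1e-6))            # True
```

Making `values` an `nn.Parameter` per layer, as the commit does, is what lets one optimizer step update all of the tiny self-replicating nets at once.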