sparse network redo

This commit is contained in:
Steffen Illium
2022-02-20 21:21:22 +01:00
parent 52081d176e
commit f25cee5203
7 changed files with 365 additions and 270 deletions

View File

@@ -1,9 +1,17 @@
# Bureaucratic Cohort Swarms
### (The Meta-Task Experience) # Deadline: 28.02.22
## Experiments
### Pruning Networks by SRNN
###### Deadline: 28.02.22
Data Exchange: [Google Drive Folder](***REMOVED***)
Paper Template: [Overleaf Project](***REMOVED***)
## Experiments
### Fixpoint Tests:
- [X] Dropout Test

View File

@@ -1,40 +0,0 @@
import numpy as np
import torch
import pandas as pd
import re
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt
from network import FixTypes
if __name__ == '__main__':
p = Path(r'experiments\output\mn_st_200_4_alpha_100\trained_model_ckpt_e200.tp')
m = torch.load(p, map_location=torch.device('cpu'))
particles = [y for x in m._meta_layer_list for y in x.particles]
df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name', 'color'])
colors = []
for particle in particles:
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", particle.name).split('_')]
color = sns.color_palette()[0 if particle.is_fixpoint == FixTypes.identity_func else 1]
# color = 'orange' if particle.is_fixpoint == FixTypes.identity_func else 'blue'
colors.append(color)
df.loc[df.shape[0]] = (particle.is_fixpoint, l-1, w, particle.name, color)
df.loc[df.shape[0]] = (particle.is_fixpoint, l, c, particle.name, color)
for layer in list(df['layer'].unique()):
divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
df.loc[(df['layer'] == layer), 'neuron'] /= divisor
print('gathered')
for n, (fixtype, color) in enumerate(zip([FixTypes.other_func, FixTypes.identity_func], ['blue', 'orange'])):
plt.clf()
ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
legend=False, estimator=None,
palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
# ax.set(yscale='log', ylabel='Neuron')
ax.set_title(fixtype)
plt.show()
print('plotted')
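
The `l, c, w` unpacking above relies on the particle naming scheme; a minimal sketch, assuming a hypothetical name such as `'L2_N3_W7'` (layer, cell, weight index):

```python
import re

# Hypothetical particle name; the real scheme only needs to embed three indices.
name = 'L2_N3_W7'
# Strip everything except digits and underscores, then split into the indices.
l, c, w = (float(x) for x in re.sub("[^0-9|_]", "", name).split('_'))
print(l, c, w)  # 2.0 3.0 7.0
```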

View File

@@ -1,4 +1,5 @@
import pickle
import re
from collections import defaultdict
from pathlib import Path
import sys
@@ -17,7 +18,7 @@ from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm
# noinspection DuplicatedCode
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
@@ -37,14 +38,15 @@ else:
DIR = None
pass
from network import MetaNet, FixTypes
from network import MetaNet, FixTypes as ft
from sparse_net import SparseNetwork
from functionalities_test import test_for_fixpoints
WORKER = 10 if not debug else 2
debug = False
BATCHSIZE = 500 if not debug else 50
EPOCH = 200
VALIDATION_FRQ = 5 if not debug else 1
VALIDATION_FRQ = 3 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -139,7 +141,7 @@ def plot_training_particle_types(path_to_dataframe):
df = pd.read_csv(path_to_dataframe, index_col=False)
# Set up figure
fig, ax = plt.subplots() # initializes figure and plots
data = df[df['Metric'].isin(FixTypes.all_types())]
data = df.loc[df['Metric'].isin(ft.all_types())]
fix_types = data['Metric'].unique()
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types], labels=fix_types.tolist())
@@ -189,196 +191,253 @@ def plot_training_result(path_to_dataframe):
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
def plot_network_connectivity_by_fixtype(path_to_trained_model):
m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
# noinspection PyProtectedMember
particles = [y for x in m._meta_layer_list for y in x.particles]
df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
for prtcl in particles:
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
for layer in list(df['layer'].unique()):
# Rescale
divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
df.loc[(df['layer'] == layer), 'neuron'] /= divisor
print('gathered')
for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
plt.clf()
ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
legend=False, estimator=None,
palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
ax.set_title(fixtype)
plt.show()
print('plotted')
def run_particle_dropout_test(run_path):
diff_store_path = run_path / 'diff_store.csv'
prtcl_dict = defaultdict(lambda: 0)
_ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
tqdm.write(str(dict(prtcl_dict)))
acc_pre = validate(model_path, ratio=1).item()
diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
for fixpoint_type in ft.all_types():
new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
acc_post = validate(new_ckpt, ratio=1).item()
acc_diff = abs(acc_post - acc_pre)
tqdm.write(f'Zero_ident diff = {acc_diff}')
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
return diff_store_path
def plot_dropout_stacked_barplot(path_to_diff_df):
diff_df = pd.read_csv(path_to_diff_df)
particle_dict = defaultdict(lambda: 0)
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
tqdm.write(str(dict(particle_dict)))
plt.clf()
fig, ax = plt.subplots(ncols=2)
colors = sns.color_palette()[:diff_df.shape[0]]
barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
# noinspection PyUnboundLocalVariable
for idx, patch in enumerate(barplot.patches):
if idx != 0:
# we recenter the bar
patch.set_x(patch.get_x() + idx * 0.035)
ax[0].set_title('Accuracy after particle dropout')
ax[0].set_xlabel('Accuracy')
ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
ax[1].set_title('Particle Count for ')
plt.tight_layout()
if debug:
plt.show()
else:
plt.savefig(Path(path_to_diff_df.parent / 'dropout_stacked_barplot.png'), dpi=300)
def run_particle_dropout_and_plot(run_path):
diff_store_path = run_particle_dropout_test(run_path)
plot_dropout_stacked_barplot(diff_store_path)
def flat_for_store(parameters):
return (x.item() for y in parameters for x in y.detach().flatten())
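`flat_for_store` lazily yields every scalar of a particle's parameters in a fixed order, so one particle fits into a single CSV row; a quick check with a stand-in `nn.Linear` (not part of this repo):

```python
from torch import nn

def flat_for_store(parameters):
    return (x.item() for y in parameters for x in y.detach().flatten())

# A 2x2 Linear has 4 weights + 2 biases, so the generator yields 6 plain floats.
print(len(list(flat_for_store(nn.Linear(2, 2).parameters()))))  # 6
```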
if __name__ == '__main__':
use_sparse_implementation = True
self_train = True
training = False
plotting = True
particle_analysis = True
as_sparse_network_test = True
training = True
train_to_id_first = False
self_train_alpha = 100
train_to_task_first = False
train_to_task_first_sequential = True
tsk_threshold = 0.855
self_train_alpha = 1
batch_train_beta = 1
weight_hidden_size = 4
weight_hidden_size = 3
residual_skip = True
dropout = 0
n_seeds = 2
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
assert not (train_to_task_first and train_to_id_first)
st_str = f'{"" if self_train else "no_"}st'
a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
res_str = f'{"" if residual_skip else "_no"}_res'
dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
res_str = f'{"" if residual_skip else "_no_res"}'
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
id_str = f'{f"_StToId" if train_to_id_first else ""}'
run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{dr_str}{id_str}'
tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first else ""}'
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}'
model_path = run_path / '0000_trained_model.zip'
df_store_path = run_path / 'train_store.csv'
weight_store_path = run_path / 'weight_store.csv'
srnn_parameters = dict()
if use_sparse_implementation:
metanet_class = SparseNetwork
else:
metanet_class = MetaNet
if training:
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
dataset = MNIST(str(data_path), transform=utility_transforms)
except RuntimeError:
dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
for seed in range(n_seeds):
seed_path = exp_path / str(seed)
interface = np.prod(dataset[0][0].shape)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip, dropout=dropout,
weight_hidden_size=weight_hidden_size,
).to(DEVICE)
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
model_path = seed_path / '0000_trained_model.zip'
df_store_path = seed_path / 'train_store.csv'
weight_store_path = seed_path / 'weight_store.csv'
srnn_parameters = dict()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
if training:
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
dataset = MNIST(str(data_path), transform=utility_transforms)
except RuntimeError:
dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', meta_weight_count)
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
metanet = metanet.train()
if is_validation_epoch:
metric = torchmetrics.Accuracy()
else:
metric = None
init_st = train_to_id_first and all(x.is_fixpoint == FixTypes.identity_func for x in metanet.particles)
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if (self_train and is_self_train_epoch) or init_st:
# Zero your gradients for every batch!
optimizer.zero_grad()
self_train_loss = metanet.combined_self_train() * self_train_alpha
self_train_loss.backward()
# Adjust learning weights
optimizer.step()
step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
train_store.loc[train_store.shape[0]] = step_log
if train_to_id_first <= epoch:
# Zero your gradients for every batch!
optimizer.zero_grad()
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y = metanet(batch_x)
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
loss.backward()
interface = np.prod(dataset[0][0].shape)
metanet = metanet_class(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
weight_hidden_size=weight_hidden_size,).to(DEVICE)
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
# Adjust learning weights
optimizer.step()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
step_log = dict(Epoch=epoch, Batch=batch,
Metric='Task Loss', Score=loss.item())
train_store.loc[train_store.shape[0]] = step_log
if is_validation_epoch:
metric(y.cpu(), batch_y.cpu())
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', meta_weight_count)
init_tsk = train_to_task_first
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
metanet = metanet.train()
if is_validation_epoch:
metric = torchmetrics.Accuracy()
else:
metric = None
init_st = train_to_id_first and not all(x.is_fixpoint == ft.identity_func for x in metanet.particles)
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if self_train and not init_tsk and (is_self_train_epoch or init_st):
# Zero your gradients for every batch!
optimizer.zero_grad()
self_train_loss = metanet.combined_self_train() * self_train_alpha
self_train_loss.backward()
# Adjust learning weights
optimizer.step()
step_log = dict(Epoch=epoch, Batch=batch,
Metric='Self Train Loss', Score=self_train_loss.item())
train_store.loc[train_store.shape[0]] = step_log
if not init_st:
# Zero your gradients for every batch!
optimizer.zero_grad()
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y_pred = metanet(batch_x)
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
loss.backward()
if batch >= 3 and debug:
break
# Adjust learning weights
optimizer.step()
if is_validation_epoch:
metanet = metanet.eval()
if train_to_id_first <= epoch:
step_log = dict(Epoch=epoch, Batch=batch,
Metric='Task Loss', Score=loss.item())
train_store.loc[train_store.shape[0]] = step_log
if is_validation_epoch:
metric(y_pred.cpu(), batch_y.cpu())
if batch >= 3 and debug:
break
if is_validation_epoch:
metanet = metanet.eval()
if train_to_id_first <= epoch:
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Train Accuracy', Score=metric.compute().item())
train_store.loc[train_store.shape[0]] = validation_log
accuracy = checkpoint_and_validate(metanet, seed_path, epoch).item()
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Train Accuracy', Score=metric.compute().item())
Metric='Test Accuracy', Score=accuracy)
train_store.loc[train_store.shape[0]] = validation_log
if init_tsk or (train_to_task_first and train_to_task_first_sequential):
init_tsk = accuracy <= tsk_threshold
if init_st or is_validation_epoch:
counter_dict = defaultdict(lambda: 0)
# This returns ID-functions
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
for key, value in dict(counter_dict).items():
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
if init_st or is_validation_epoch:
for particle in metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', meta_weight_count)
accuracy = checkpoint_and_validate(metanet, run_path, epoch)
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
train_store.loc[train_store.shape[0]] = validation_log
if particle_analysis and (init_st or is_validation_epoch):
counter_dict = defaultdict(lambda: 0)
# This returns ID-functions
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
for key, value in dict(counter_dict).items():
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
if init_st or is_validation_epoch:
for particle in metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', meta_weight_count)
metanet.eval()
metanet.eval()
if particle_analysis:
counter_dict = defaultdict(lambda: 0)
# This returns ID-functions
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
for key, value in dict(counter_dict).items():
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
for particle in metanet.particles:
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
weight_store.loc[weight_store.shape[0]] = weight_log
accuracy = checkpoint_and_validate(metanet, seed_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
for particle in metanet.particles:
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
weight_store.loc[weight_store.shape[0]] = weight_log
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
if plotting:
plot_training_result(df_store_path)
if particle_analysis:
plot_training_particle_types(df_store_path)
plot_training_particle_types(df_store_path)
if particle_analysis:
try:
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
except StopIteration:
print('Model pattern did not trigger.')
print(f'Search path was: {run_path}:')
print(f'Found Models are: {list(run_path.rglob(".tp"))}')
print(f'Search path was: {seed_path}:')
print(f'Found Models are: {list(seed_path.rglob("*.tp"))}')
exit(1)
latest_model = torch.load(model_path, map_location=DEVICE).eval()
counter_dict = defaultdict(lambda: 0)
_ = test_for_fixpoints(counter_dict, list(latest_model.particles))
tqdm.write(str(dict(counter_dict)))
if as_sparse_network_test:
acc_pre = validate(model_path, ratio=1).item()
diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff'])
for fixpoint_type in FixTypes.all_types():
new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
acc_post = validate(new_ckpt, ratio=1).item()
acc_diff = abs(acc_post-acc_pre)
tqdm.write(f'Zero_ident diff = {acc_diff}')
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
run_particle_dropout_and_plot(seed_path)
plot_network_connectivity_by_fixtype(model_path)
if plotting:
plt.clf()
fig, ax = plt.subplots(ncols=2)
labels = ['Full Network', 'Sparse, No Identity', 'Sparse, No Other']
colors = sns.color_palette()[:diff_df.shape[0]] if diff_df.shape[0] >= 2 else sns.color_palette()[0]
barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
# noinspection PyUnboundLocalVariable
for idx, patch in enumerate(barplot.patches):
if idx != 0:
# we recenter the bar
patch.set_x(patch.get_x() + idx * 0.035)
ax[0].set_title('Accuracy after particle dropout')
ax[0].set_xlabel('Accuracy')
# ax[0].legend()
ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=colors, )
ax[1].set_title('Particle Count for ')
# ax[1].set_xlabel('')
plt.tight_layout()
if debug:
plt.show()
else:
plt.savefig(Path(run_path / 'dropout_stacked_barplot.png'), dpi=300)
if n_seeds >= 2:
pass

View File

View File

@@ -68,7 +68,6 @@ class Net(nn.Module):
for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
self.state_dict()[layer_name][line_id][weight_id] = new_weights[i]
i += 1
return self
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
@@ -100,7 +99,6 @@ class Net(nn.Module):
self._weight_pos_enc_and_mask = None
@property
def _weight_pos_enc(self):
if self._weight_pos_enc_and_mask is None:
@@ -127,8 +125,8 @@ class Net(nn.Module):
# Normalize 1,2,3 column of dim 1
last_pos_idx = self.input_size - 4
norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
max_per_col, _ = weight_matrix[:, 1:-last_pos_idx].max(keepdim=True, dim=0)
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / max_per_col) + 1e-8
# computations
# create a mask where pos is 0 if it is to be replaced
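The change above swaps the L2 normalisation of the positional columns for a per-column maximum, so each encoded position ends up in (0, 1] before the 1e-8 shift; a toy 3x4 matrix, with `last_pos_idx` assumed to be 1, shows the rescaling:

```python
import torch

weight_matrix = torch.tensor([[1., 2., 10., 0.5],
                              [1., 4., 20., 0.5],
                              [1., 8., 40., 0.5]])
last_pos_idx = 1  # assumed for this toy matrix; the real value is input_size - 4
cols = weight_matrix[:, 1:-last_pos_idx]         # the positional columns
max_per_col, _ = cols.max(keepdim=True, dim=0)   # column-wise maximum
weight_matrix[:, 1:-last_pos_idx] = cols / max_per_col + 1e-8
print(weight_matrix[:, 1:-last_pos_idx])         # every column now peaks at ~1.0
```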
@@ -389,6 +387,7 @@ class MetaNet(nn.Module):
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0,
weight_interface=5, weight_hidden_size=2, weight_output_size=1,):
super().__init__()
self.residual_skip = residual_skip
self.dropout = dropout
self.activation = activation
self.out = out
@@ -398,7 +397,6 @@ class MetaNet(nn.Module):
self.weight_interface = weight_interface
self.weight_hidden_size = weight_hidden_size
self.weight_output_size = weight_output_size
self._meta_layer_first = MetaLayer(name=f'L{0}',
interface=self.interface,
width=self.width,
@@ -411,6 +409,7 @@ class MetaNet(nn.Module):
weight_interface=weight_interface,
weight_hidden_size=weight_hidden_size,
weight_output_size=weight_output_size,
) for layer_idx in range(self.depth - 2)]
)
self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
@@ -441,10 +440,10 @@ class MetaNet(nn.Module):
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
if self.dropout != 0:
tensor = self.dropout_layer(tensor)
if idx % 2 == 1:
if idx % 2 == 1 and self.residual_skip:
x = tensor.clone()
tensor = meta_layer(tensor)
if idx % 2 == 0:
if idx % 2 == 0 and self.residual_skip:
tensor = tensor + x
if self.dropout != 0:
x = self.dropout_layer(x)
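
`residual_skip` now gates the alternating skip pattern: a copy is taken on odd layer indices and added back on even ones. A minimal sketch with plain `nn.Linear` layers standing in for the meta-layers (the class name here is made up):

```python
import torch
from torch import nn

class AlternatingSkipStack(nn.Module):
    def __init__(self, width=6, depth=4, residual_skip=True):
        super().__init__()
        self.residual_skip = residual_skip
        self.layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))

    def forward(self, tensor):
        for idx, layer in enumerate(self.layers, start=1):
            if idx % 2 == 1 and self.residual_skip:
                x = tensor.clone()      # open a skip over the next two layers
            tensor = layer(tensor)
            if idx % 2 == 0 and self.residual_skip:
                tensor = tensor + x     # close the skip
        return tensor

print(AlternatingSkipStack()(torch.randn(3, 6)).shape)  # torch.Size([3, 6])
```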

View File

@@ -56,7 +56,7 @@ if __name__ == '__main__':
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
loss_fn = nn.CrossEntropyLoss()
model = torch.load("trained_model_ckpt_e200.tp", map_location=DEVICE).eval()
model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
test_weights_as_model(weights, d_test)

View File

@@ -1,85 +1,114 @@
from torch import nn
from network import Net
from typing import List
from functionalities_test import is_identity_function
from tqdm import tqdm,trange
import numpy as np
from pathlib import Path
import torch
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
class SparseLayer():
class SparseLayer(nn.Module):
def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
super(SparseLayer, self).__init__()
self.nr_nets = nr_nets
self.interface_dim = interface
self.depth_dim = depth
self.hidden_dim = width
self.out_dim = out
self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
self.sparse_sub_layer = []
self.weights = []
for layer_id in range(depth):
layer, weights = self.coo_sparse_layer(layer_id)
self.sparse_sub_layer.append(layer)
self.sparse_sub_layer = list()
self.indices = list()
self.diag_shapes = list()
self.weights = nn.ParameterList()
self._particles = None
for layer_id in range(self.depth_dim):
indices, weights, diag_shape = self.coo_sparse_layer(layer_id)
self.indices.append(indices)
self.diag_shapes.append(diag_shape)
self.weights.append(weights)
def coo_sparse_layer(self, layer_id):
layer_shape = list(self.dummy_net.parameters())[layer_id].shape
#print(layer_shape) #(out_cells, in_cells) -> (2,5), (2,2), (1,2)
sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
indices = np.argwhere(sparse_diagonal == 1).T
values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))
#values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))
s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)
print(f"L{layer_id}:", s.shape)
return s, values
indices = torch.Tensor(np.argwhere(sparse_diagonal == 1).T)
values = torch.nn.Parameter(
torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]))), requires_grad=True
)
return indices, values, sparse_diagonal.shape
def get_self_train_inputs_and_targets(self):
encoding_matrix, mask = self.dummy_net._weight_pos_enc
# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN
# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net
weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights] #[nr_layers*[nr_net*nr_weights_layer_i]]
weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)] #[nr_net*[nr_weights]]
inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)]) #(16, 25)
# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2
# and first hidden*out weights of layer3 = first net
# [nr_layers*[nr_net*nr_weights_layer_i]]
weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights]
# [nr_net*[nr_weights]]
weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]
# (16, 25)
inputs = torch.hstack(
[encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
for i in range(self.nr_nets)]
)
targets = torch.hstack(weights_per_net)
return inputs.T, targets.T
return inputs.T.detach(), targets.T.detach()
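The inputs are built by keeping the positional encoding wherever the mask is 1 and inserting the particle's own weight value wherever it is 0; a miniature of that blend (all numbers invented):

```python
import torch

encoding = torch.tensor([[0.1, 0.2, 0.0],
                         [0.3, 0.4, 0.0]])
mask     = torch.tensor([[1.,  1.,  0.],
                         [1.,  1.,  0.]])
weights_per_net = torch.tensor([[0.7], [0.9]])   # one net's weights as a column
inputs = encoding * mask + weights_per_net.expand(-1, encoding.shape[-1]) * (1 - mask)
print(inputs)  # first two columns keep the encoding, the last carries 0.7 and 0.9
```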
@property
def particles(self):
if self._particles is None:
self._particles = [Net(self.interface_dim, self.hidden_dim, self.out_dim) for _ in range(self.nr_nets)]
pass
else:
pass
# Particle Weight Update
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
range(self.nr_nets)] # [nr_net*[nr_weights]]
for particles, weights in zip(self._particles, weights_per_net):
particles.apply_weights(weights)
return self._particles
def __call__(self, x):
X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)
#print("X1", X1.shape)
for indices, diag_shapes, weights in zip(self.indices, self.diag_shapes, self.weights):
s = torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
x = torch.sparse.mm(s, x)
return x
X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)
#print("X2", X2.shape)
X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)
#print("X3", X3.shape)
return X3
def to(self, *args, **kwargs):
super(SparseLayer, self).to(*args, **kwargs)
self.sparse_sub_layer = [sparse_sub_layer.to(*args, **kwargs) for sparse_sub_layer in self.sparse_sub_layer]
return self
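
`coo_sparse_layer` packs all sub-networks of one depth level into a single block-diagonal sparse matrix, so one sparse `mm` evaluates every tiny net at once; a reduced example with three 2x5 blocks:

```python
import numpy as np
import torch

nr_nets, out_cells, in_cells = 3, 2, 5
# Block-diagonal pattern: one (out_cells x in_cells) block per sub-network.
diag = np.eye(nr_nets).repeat(out_cells, axis=-2).repeat(in_cells, axis=-1)
indices = torch.tensor(np.argwhere(diag == 1).T)
values = torch.randn(nr_nets * out_cells * in_cells, requires_grad=True)
block_diag = torch.sparse_coo_tensor(indices, values, diag.shape)

x = torch.randn(nr_nets * in_cells, 1)   # the inputs of all nets, stacked
y = torch.sparse.mm(block_diag, x)       # one call == nr_nets small matmuls
print(y.shape)                           # torch.Size([6, 1]) -> (nr_nets * out_cells, 1)
```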
def test_sparse_layer():
net = SparseLayer(500)  # 500 parallel nets
loss_fn = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)
#optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
optimizer = torch.optim.SGD(net.weights, lr=0.004, momentum=0.9)
# optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
for train_iteration in trange(1000):
optimizer.zero_grad()
optimizer.zero_grad()
X,Y = net.get_self_train_inputs_and_targets()
out = net(X)
loss = loss_fn(out, Y)
# print("X:", X.shape, "Y:", Y.shape)
# print("OUT", out.shape)
# print("LOSS", loss.item())
loss.backward(retain_graph=True)
optimizer.step()
@@ -88,54 +117,95 @@ def test_sparse_layer():
print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
def embed_batch(x, repeat_dim):
# x of shape (batchsize, flat_img_dim)
x = x.unsqueeze(-1) #(batchsize, flat_img_dim, 1)
return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)
return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)
def embed_vector(x, repeat_dim):
# x of shape [flat_img_dim]
x = x.unsqueeze(-1) #(flat_img_dim, 1)
return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim) #(flat_img_dim, encoding_dim*repeat_dim)
x = x.unsqueeze(-1) # (flat_img_dim, 1)
# (flat_img_dim, encoding_dim*repeat_dim)
return torch.cat((torch.zeros(x.shape[0], 4), x), dim=1).repeat(1,repeat_dim)
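`embed_vector` pads every pixel row with the four zeroed position-encoding slots and then tiles the result once per sub-network; the shapes for an assumed 15x15 input and 10 nets:

```python
import torch

flat_img = torch.rand(225)                                    # 15 * 15 pixels, flattened
x = flat_img.unsqueeze(-1)                                    # (225, 1)
embedded = torch.cat((torch.zeros(x.shape[0], 4), x), dim=1)  # (225, 5): [0, 0, 0, 0, pixel]
embedded = embedded.repeat(1, 10)                             # (225, 50) for 10 sub-networks
print(embedded.shape)                                         # torch.Size([225, 50])
```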
class SparseNetwork():
def __init__(self, input_dim, depth, width, out):
class SparseNetwork(nn.Module):
def __init__(self, input_dim, depth, width, out, residual_skip=True,
weight_interface=5, weight_hidden_size=2, weight_output_size=1
):
super(SparseNetwork, self).__init__()
self.residual_skip = residual_skip
self.input_dim = input_dim
self.depth_dim = depth
self.hidden_dim = width
self.out_dim = out
self.sparse_layers = []
self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))
self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])
self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))
self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
self.hidden_layers = nn.ModuleList([
SparseLayer(self.hidden_dim * self.hidden_dim,
interface=weight_interface, width=weight_hidden_size, out=weight_output_size
) for _ in range(self.depth_dim - 2)])
def __call__(self, x):
for sparse_layer in self.sparse_layers[:-1]:
# batch pass (one by one, sparse bmm doesn't support grad)
if len(x.shape) > 1:
embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]
# vector
else:
embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)
print("out", x.shape)
# output layer
sparse_layer = self.sparse_layers[-1]
tensor = self.sparse_layer_forward(x, self.first_layer)
for nl_idx, network_layer in enumerate(self.hidden_layers):
if nl_idx % 2 == 0 and self.residual_skip:
residual = tensor.clone()
# Sparse Layer pass
tensor = self.sparse_layer_forward(tensor, network_layer)
if nl_idx % 2 != 0 and self.residual_skip:
# noinspection PyUnboundLocalVariable
tensor += residual
tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
return tensor
def sparse_layer_forward(self, x, sparse_layer, view_dim=None):
view_dim = view_dim if view_dim else self.hidden_dim
# batch pass (one by one, sparse bmm doesn't support grad)
if len(x.shape) > 1:
embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]
# [batchsize, hidden*inpt_dim, feature_dim]
x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1) for inpt in
embedded_inpt])
# vector
else:
embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)
print("out", x.shape)
x = sparse_layer(embedded_inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1)
return x
@property
def particles(self):
particles = []
particles.extend(self.first_layer.particles)
for layer in self.hidden_layers:
particles.extend(layer.particles)
particles.extend(self.last_layer.particles)
return iter(particles)
def to(self, *args, **kwargs):
super(SparseNetwork, self).to(*args, **kwargs)
self.first_layer = self.first_layer.to(*args, **kwargs)
self.last_layer = self.last_layer.to(*args, **kwargs)
self.hidden_layers = nn.ModuleList([hidden_layer.to(*args, **kwargs) for hidden_layer in self.hidden_layers])
return self
def combined_self_train(self):
import time
t = time.time()
losses = []
for layer in [self.first_layer, *self.hidden_layers, self.last_layer]:
x, target_data = layer.get_self_train_inputs_and_targets()
output = layer(x)
losses.append(F.mse_loss(output, target_data))
print('Time Taken:', time.time() - t)
return torch.hstack(losses).sum(dim=-1, keepdim=True)
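A minimal sketch of one self-train step built on the class above, with the hyper-parameters assumed rather than taken from the experiment scripts:

```python
net = SparseNetwork(input_dim=225, depth=5, width=6, out=10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)

optimizer.zero_grad()
loss = net.combined_self_train()   # summed MSE over first, hidden and last sparse layers
loss.backward()
optimizer.step()
```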
def test_sparse_net():
utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
@@ -150,7 +220,6 @@ def test_sparse_net():
data_dim = np.prod(dataset[0][0].shape)
metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
batchx, batchy = next(iter(d))
batchx.shape, batchy.shape
metanet(batchx)
@@ -176,6 +245,6 @@ def test_manual_for_loop():
if __name__ == '__main__':
test_sparse_layer()
test_sparse_net()
#for comparison
# test_sparse_net()
# for comparison
test_manual_for_loop()