apply networks are now loop free

Steffen Illium
2022-02-21 18:11:30 +01:00
parent f25cee5203
commit 2a710b40d7
3 changed files with 121 additions and 65 deletions

View File

@@ -277,7 +277,6 @@ def flat_for_store(parameters):
 if __name__ == '__main__':
-    use_sparse_implementation = True
     self_train = True
     training = True
     train_to_id_first = False
@@ -303,11 +302,6 @@ if __name__ == '__main__':
     tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first else ""}'
     exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}'
-    if use_sparse_implementation:
-        metanet_class = SparseNetwork
-    else:
-        metanet_class = MetaNet
     for seed in range(n_seeds):
         seed_path = exp_path / str(seed)
@@ -325,12 +319,15 @@ if __name__ == '__main__':
         d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
         interface = np.prod(dataset[0][0].shape)
-        metanet = metanet_class(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                weight_hidden_size=weight_hidden_size,).to(DEVICE)
-        meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
+        sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                       weight_hidden_size=weight_hidden_size,).to(DEVICE)
+        dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                weight_hidden_size=weight_hidden_size,).to(DEVICE)
+        meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())

         loss_fn = nn.CrossEntropyLoss()
-        optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
+        dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
+        sparse_optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.008, momentum=0.9)

         train_store = new_storage_df('train', None)
         weight_store = new_storage_df('weights', meta_weight_count)
@@ -338,34 +335,40 @@ if __name__ == '__main__':
         for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
             is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
             is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
-            metanet = metanet.train()
+            sparse_metanet = sparse_metanet.train()
+            dense_metanet = dense_metanet.train()
             if is_validation_epoch:
                 metric = torchmetrics.Accuracy()
             else:
                 metric = None
-            init_st = train_to_id_first and not all(x.is_fixpoint == ft.identity_func for x in metanet.particles)
+            init_st = train_to_id_first and not all(x.is_fixpoint == ft.identity_func for x in dense_metanet.particles)
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
+                # Self Train
                 if self_train and not init_tsk and (is_self_train_epoch or init_st):
+                    # Transfer weights
+                    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
                     # Zero your gradients for every batch!
-                    optimizer.zero_grad()
-                    self_train_loss = metanet.combined_self_train() * self_train_alpha
+                    sparse_optimizer.zero_grad()
+                    self_train_loss = sparse_metanet.combined_self_train() * self_train_alpha
                     self_train_loss.backward()
                     # Adjust learning weights
-                    optimizer.step()
+                    sparse_optimizer.step()
                     step_log = dict(Epoch=epoch, Batch=batch,
                                     Metric='Self Train Loss', Score=self_train_loss.item())
                     train_store.loc[train_store.shape[0]] = step_log
+                    # Transfer weights
+                    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
                 if not init_st:
                     # Zero your gradients for every batch!
-                    optimizer.zero_grad()
+                    dense_optimizer.zero_grad()
                     batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                    y_pred = metanet(batch_x)
+                    y_pred = dense_metanet(batch_x)
                     # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
                     loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
                     loss.backward()
                     # Adjust learning weights
-                    optimizer.step()
+                    dense_optimizer.step()
                     step_log = dict(Epoch=epoch, Batch=batch,
                                     Metric='Task Loss', Score=loss.item())
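
Note on the hunk above: the loop now alternates between self-training the sparse implementation and task-training the dense one, syncing weights in both directions around each self-train step. A toy sketch of that pattern using two plain linear models as stand-ins (not the project's SparseNetwork/MetaNet classes; losses are placeholders):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    sparse_model = nn.Linear(4, 2)   # stand-in for the sparse self-train implementation
    dense_model = nn.Linear(4, 2)    # stand-in for the dense task implementation
    sparse_opt = torch.optim.SGD(sparse_model.parameters(), lr=0.008, momentum=0.9)
    dense_opt = torch.optim.SGD(dense_model.parameters(), lr=0.008, momentum=0.9)

    for step in range(3):
        # transfer weights dense -> sparse, then run the "self train" update on the sparse side
        sparse_model.load_state_dict({k: v.detach().clone() for k, v in dense_model.state_dict().items()})
        sparse_opt.zero_grad()
        self_train_loss = sparse_model(torch.randn(8, 4)).pow(2).mean()
        self_train_loss.backward()
        sparse_opt.step()
        # transfer weights sparse -> dense, then run the task update on the dense side
        dense_model.load_state_dict({k: v.detach().clone() for k, v in sparse_model.state_dict().items()})
        dense_opt.zero_grad()
        task_loss = F.mse_loss(dense_model(torch.randn(8, 4)), torch.zeros(8, 2))
        task_loss.backward()
        dense_opt.step()
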
@@ -377,13 +380,13 @@ if __name__ == '__main__':
                     break
             if is_validation_epoch:
-                metanet = metanet.eval()
+                dense_metanet = dense_metanet.eval()
                 if train_to_id_first <= epoch:
                     validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                           Metric='Train Accuracy', Score=metric.compute().item())
                     train_store.loc[train_store.shape[0]] = validation_log
-                accuracy = checkpoint_and_validate(metanet, seed_path, epoch).item()
+                accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
                 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                       Metric='Test Accuracy', Score=accuracy)
                 train_store.loc[train_store.shape[0]] = validation_log
@@ -392,12 +395,12 @@ if __name__ == '__main__':
             if init_st or is_validation_epoch:
                 counter_dict = defaultdict(lambda: 0)
                 # This returns ID-functions
-                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+                _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
                 for key, value in dict(counter_dict).items():
                     step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
                     train_store.loc[train_store.shape[0]] = step_log
             if init_st or is_validation_epoch:
-                for particle in metanet.particles:
+                for particle in dense_metanet.particles:
                     weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                     weight_store.loc[weight_store.shape[0]] = weight_log
                 train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
@@ -405,18 +408,18 @@ if __name__ == '__main__':
                 train_store = new_storage_df('train', None)
                 weight_store = new_storage_df('weights', meta_weight_count)
-        metanet.eval()
+        dense_metanet.eval()
         counter_dict = defaultdict(lambda: 0)
         # This returns ID-functions
-        _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+        _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
         for key, value in dict(counter_dict).items():
             step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
             train_store.loc[train_store.shape[0]] = step_log
-        accuracy = checkpoint_and_validate(metanet, seed_path, EPOCH, final_model=True)
+        accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, final_model=True)
         validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                               Metric='Test Accuracy', Score=accuracy.item())
-        for particle in metanet.particles:
+        for particle in dense_metanet.particles:
             weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
             weight_store.loc[weight_store.shape[0]] = weight_log

View File

@@ -1,9 +1,9 @@
 # from __future__ import annotations
 import copy
 import random
-from math import sqrt
 from typing import Union

+import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
@@ -61,13 +61,14 @@ class Net(nn.Module):
     def apply_weights(self, new_weights: Tensor):
         """ Changing the weights of a network to new given values. """
-        # TODO: Change this to 'parameters' version
-        i = 0
-        for layer_id, layer_name in enumerate(self.state_dict()):
-            for line_id, line_values in enumerate(self.state_dict()[layer_name]):
-                for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
-                    self.state_dict()[layer_name][line_id][weight_id] = new_weights[i]
-                    i += 1
+        keys = self.state_dict().keys()
+        shapes = [x.shape for x in self.state_dict().values()]
+        numels = np.cumsum([0, *[x.numel() for x in self.state_dict().values()]])
+        new_state_dict = {key: new_weights[start: end].view(shape)
+                          for key, shape, start, end in zip(keys, shapes, numels, numels[1:])}
+        # noinspection PyTypeChecker
+        self.load_state_dict(new_state_dict)
         return self

     def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
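
Note on the hunk above: the loop-free apply_weights slices one flat weight vector into per-tensor views and loads them back through load_state_dict. A minimal standalone sketch of the same pattern with a toy module and a hypothetical helper name (not the project's Net.apply_weights):

    import numpy as np
    import torch
    import torch.nn as nn

    def apply_flat_weights(module: nn.Module, new_weights: torch.Tensor) -> nn.Module:
        state = module.state_dict()
        shapes = [t.shape for t in state.values()]
        # cumulative offsets [0, n0, n0+n1, ...] slice the flat vector per state tensor
        numels = np.cumsum([0, *[t.numel() for t in state.values()]])
        new_state = {key: new_weights[start:end].view(shape)
                     for key, shape, start, end in zip(state.keys(), shapes, numels, numels[1:])}
        module.load_state_dict(new_state)
        return module

    net = nn.Sequential(nn.Linear(5, 2), nn.Linear(2, 1))
    flat = torch.cat([p.detach().view(-1) for p in net.parameters()])
    apply_flat_weights(net, torch.ones_like(flat))  # e.g. set every weight and bias to 1.0
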
@@ -159,6 +160,11 @@ class Net(nn.Module):
             weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
         return weight_matrix

+    def target_weight_matrix(self) -> Tensor:
+        weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
+        return weight_matrix
+
     def self_train(self,
                    training_steps: int,
                    log_step_size: int = 0,
@@ -305,11 +311,10 @@ class MetaCell(nn.Module):
         super().__init__()
         self.name = name
         self.interface = interface
-        self.weight_interface = 5
-        self.net_hidden_size = 2
-        self.net_ouput_size = 1
-        self.meta_weight_list = nn.ModuleList()
-        self.meta_weight_list.extend(
+        self.weight_interface = weight_interface
+        self.net_hidden_size = weight_hidden_size
+        self.net_ouput_size = weight_output_size
+        self.meta_weight_list = nn.ModuleList(
             [Net(self.weight_interface, self.net_hidden_size,
                  self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
                  ) for weight_idx in range(self.interface)]
@@ -360,13 +365,13 @@ class MetaLayer(nn.Module):
         self.interface = interface
         self.width = width
-        self.meta_cell_list = nn.ModuleList()
-        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
+        self.meta_cell_list = nn.ModuleList([
+            MetaCell(name=f'{self.name}_C{cell_idx}',
                      interface=interface,
                      weight_interface=weight_interface, weight_hidden_size=weight_hidden_size,
                      weight_output_size=weight_output_size,
                      ) for cell_idx in range(self.width)]
            )

     def forward(self, x):
         cell_results = []
@@ -468,6 +473,14 @@ class MetaNet(nn.Module):
     def hyperparams(self):
         return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}

+    def replace_particles(self, particle_weights_list):
+        for layer in self._all_layers_with_particles:
+            for cell in layer.meta_cell_list:
+                # Individual replacement on cell lvl
+                for weight in cell.meta_weight_list:
+                    weight.apply_weights(next(particle_weights_list))
+        return self
+

 class MetaNetCompareBaseline(nn.Module):

View File

@@ -22,7 +22,9 @@ class SparseLayer(nn.Module):
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
-        self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
+        dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
+        self.dummy_net_shapes = [list(x.shape) for x in dummy_net.parameters()]
+        self.dummy_net_weight_pos_enc = dummy_net._weight_pos_enc

         self.sparse_sub_layer = list()
         self.indices = list()
@@ -37,18 +39,14 @@ class SparseLayer(nn.Module):
             self.weights.append(weights)

     def coo_sparse_layer(self, layer_id):
-        layer_shape = list(self.dummy_net.parameters())[layer_id].shape
+        layer_shape = self.dummy_net_shapes[layer_id]
         sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
         indices = torch.Tensor(np.argwhere(sparse_diagonal == 1).T)
-        values = torch.nn.Parameter(
-            torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]))), requires_grad=True
-        )
+        values = torch.nn.Parameter(torch.randn((np.prod((*layer_shape, self.nr_nets)).item())), requires_grad=True)
         return indices, values, sparse_diagonal.shape

     def get_self_train_inputs_and_targets(self):
-        encoding_matrix, mask = self.dummy_net._weight_pos_enc
         # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN
         # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2
         # and first hidden*out weights of layer3 = first net
@@ -57,6 +55,13 @@ class SparseLayer(nn.Module):
         # [nr_net*[nr_weights]]
         weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]
         # (16, 25)
+        encoding_matrix, mask = self.dummy_net_weight_pos_enc
+        weight_device = weights_per_net[0].device
+        if weight_device != encoding_matrix.device or weight_device != mask.device:
+            encoding_matrix, mask = encoding_matrix.to(weight_device), mask.to(weight_device)
+            self.dummy_net_weight_pos_enc = encoding_matrix, mask
         inputs = torch.hstack(
             [encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
              for i in range(self.nr_nets)]
@@ -80,6 +85,24 @@ class SparseLayer(nn.Module):
             particles.apply_weights(weights)
         return self._particles

+    @property
+    def particle_weights(self):
+        all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
+        weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1)
+                           for i in range(self.nr_nets)]  # [nr_net*[nr_weights]]
+        return weights_per_net
+
+    def replace_weights_by_particles(self, particles):
+        assert len(particles) == self.nr_nets
+        # Particle Weight Update
+        all_weights = [list(particle.parameters()) for particle in particles]
+        all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
+        # [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
+        for widx, (weights, key) in enumerate(zip(all_weights, self.state_dict().keys())):
+            self.state_dict()[key] = weights[:]
+        return self
+
     def __call__(self, x):
         for indices, diag_shapes, weights in zip(self.indices, self.diag_shapes, self.weights):
             s = torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
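
Note on the hunk above: the zip/cat step in replace_weights_by_particles regroups a per-particle list of parameter tensors into one flat value tensor per sparse sub-layer. A toy sketch with hypothetical tensor shapes (stand-ins for the particles' parameters):

    import torch

    nr_nets = 3
    # stand-ins for list(particle.parameters()) of three small nets
    all_weights = [[torch.full((2, 5), float(i)), torch.full((2, 2), float(i)), torch.full((1, 2), float(i))]
                   for i in range(nr_nets)]
    # zip(*...) groups the first parameter of every net, then the second, and so on
    per_layer = [torch.cat(group).view(-1) for group in zip(*all_weights)]
    # one flat value vector per sparse sub-layer, containing nr_nets blocks each
    assert [v.numel() for v in per_layer] == [2 * 5 * nr_nets, 2 * 2 * nr_nets, 1 * 2 * nr_nets]
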
@@ -119,9 +142,12 @@ def test_sparse_layer():

 def embed_batch(x, repeat_dim):
     # x of shape (batchsize, flat_img_dim)
-    x = x.unsqueeze(-1)  # (batchsize, flat_img_dim, 1)
-    return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim)  # (batchsize, flat_img_dim, encoding_dim*repeat_dim)
+    # (batchsize, flat_img_dim, 1)
+    x = x.unsqueeze(-1)
+    # (batchsize, flat_img_dim, encoding_dim*repeat_dim)
+    # torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
+    return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim)


 def embed_vector(x, repeat_dim):
     # x of shape [flat_img_dim]
@@ -154,7 +180,7 @@ class SparseNetwork(nn.Module):
         tensor = self.sparse_layer_forward(x, self.first_layer)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
             if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor.clone()
+                residual = tensor
             # Sparse Layer pass
             tensor = self.sparse_layer_forward(tensor, network_layer)
@@ -180,12 +206,18 @@ class SparseNetwork(nn.Module):
     @property
     def particles(self):
-        particles = []
-        particles.extend(self.first_layer.particles)
-        for layer in self.hidden_layers:
-            particles.extend(layer.particles)
-        particles.extend(self.last_layer.particles)
-        return iter(particles)
+        #particles = []
+        #particles.extend(self.first_layer.particles)
+        #for layer in self.hidden_layers:
+        #    particles.extend(layer.particles)
+        #particles.extend(self.last_layer.particles)
+        return (x for y in (self.first_layer.particles,
+                            *(l.particles for l in self.hidden_layers),
+                            self.last_layer.particles) for x in y)
+
+    @property
+    def particle_weights(self):
+        return (x for y in self.sparselayers for x in y.particle_weights)

     def to(self, *args, **kwargs):
         super(SparseNetwork, self).to(*args, **kwargs)
@@ -194,18 +226,26 @@ class SparseNetwork(nn.Module):
         self.hidden_layers = nn.ModuleList([hidden_layer.to(*args, **kwargs) for hidden_layer in self.hidden_layers])
         return self

+    @property
+    def sparselayers(self):
+        return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
+
     def combined_self_train(self):
+        import time
+        t = time.time()
         losses = []
-        for layer in [self.first_layer, *self.hidden_layers, self.last_layer]:
+        for layer in self.sparselayers:
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)
             losses.append(F.mse_loss(output, target_data))
+        print('Time Taken:', time.time() - t)
         return torch.hstack(losses).sum(dim=-1, keepdim=True)

+    def replace_weights_by_particles(self, particles):
+        particles = list(particles)
+        for layer in self.sparselayers:
+            layer.replace_weights_by_particles(particles[:layer.nr_nets])
+            del particles[:layer.nr_nets]
+        return self
+

 def test_sparse_net():
     utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])