new sanity method
@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 100
+EPOCH = 50
 VALIDATION_FRQ = 3 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -279,9 +279,9 @@ if __name__ == '__main__':
 
     self_train = True
     training = True
-    train_to_id_first = True
+    train_to_id_first = False
     train_to_task_first = False
-    train_to_task_first_sequential = False
+    train_to_task_first_sequential = True
     force_st_for_n_from_last_epochs = 5
 
     use_sparse_network = False
@@ -303,10 +303,12 @@ if __name__ == '__main__':
     # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
     id_str = f'{f"_StToId" if train_to_id_first else ""}'
     tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
+    sprs_str = '_sprs' if use_sparse_network else ''
     f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
         force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
         else ""
-    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}{f_str}'
+    config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
+    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
 
     for seed in range(n_seeds):
         seed_path = exp_path / str(seed)
@@ -358,8 +360,8 @@ if __name__ == '__main__':
             force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
                         ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
-                # Self Train
 
+                # Self Train
                 if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
                     # Transfer weights
                     if use_sparse_network:
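
As a sanity check on the window arithmetic above: `force_st` can only fire inside the last `force_st_for_n_from_last_epochs` epochs of the run (and only in the sequential task-first mode). A toy sketch of just the window condition, assuming a zero-based epoch counter and the new EPOCH = 50:

EPOCH = 50
force_st_for_n_from_last_epochs = 5

# The comparison flags exactly the final N epochs of the run.
flagged = [epoch for epoch in range(EPOCH)
           if force_st_for_n_from_last_epochs >= (EPOCH - epoch)]
print(flagged)  # [45, 46, 47, 48, 49]
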
@@ -376,6 +378,8 @@ if __name__ == '__main__':
                     # Transfer weights
                     if use_sparse_network:
                         dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+
+                # Task Train
                 if not init_st:
                     # Zero your gradients for every batch!
                     dense_optimizer.zero_grad()
@@ -11,6 +11,11 @@ from torch import optim, Tensor
 from tqdm import tqdm
 
 
+def xavier_init(m):
+    if isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight.data)
+
+
 def prng():
     return random.random()
 
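
For context on the new initializer: `nn.Module.apply(fn)` walks every submodule recursively, so `xavier_init` only needs to branch on the layer type. A minimal, self-contained sketch of the same pattern (the `demo` model is illustrative, not from this repo):

import torch.nn as nn

def xavier_init(m):
    # Only affine layers carry a weight matrix worth re-initializing here.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)

demo = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
demo.apply(xavier_init)  # visits every submodule, as Net.__init__ does below
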
@@ -97,6 +102,7 @@ class Net(nn.Module):
         )
 
         self._weight_pos_enc_and_mask = None
+        self.apply(xavier_init)
 
     @property
     def _weight_pos_enc(self):
@@ -503,7 +509,7 @@ class MetaNetCompareBaseline(nn.Module):
 
 
 if __name__ == '__main__':
-    metanet = MetaNet(interface=3, depth=5, width=3, out=1, dropout=0.0, residual_skip=True)
+    metanet = MetaNet(interface=3, depth=5, width=3, out=1, residual_skip=True)
     next(metanet.particles).input_weight_matrix()
     metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)]))
     a = metanet.particles
@@ -1,3 +1,5 @@
+from collections import defaultdict
+
 from tqdm import tqdm
 import pandas as pd
 from pathlib import Path
@@ -15,18 +17,20 @@ def extract_weights_from_model(model:MetaNet)->dict:
     inpt[-1] = 1
     inpt.long()
 
-    weights = {i:[] for i in range(model.depth)}
+    weights = defaultdict(list)
     layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
-    for i,layer in enumerate(layers):
+    for i, layer in enumerate(layers):
         for net in layer:
             weights[i].append(net(inpt).detach())
-    return weights
+    return dict(weights)
 
-def test_weights_as_model(model, weights:dict, data):
+
+def test_weights_as_model(model, new_weights:dict, data):
     TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
+
     with torch.no_grad():
-        for i, weight_set in weights.items():
-            TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
+        for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
+            parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
 
     TransferNet.eval()
     metric = torchmetrics.Accuracy()
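
The rewritten transfer loop leans on in-place slice assignment under `torch.no_grad()`: the `Parameter` objects stay alive, only their values are overwritten, and `zip` pairs each extracted weight set with the matching parameter tensor in declaration order. A minimal sketch of that idiom, with hypothetical flat weight sets standing in for the extracted particle outputs:

import torch
import torch.nn as nn

# Hypothetical weight sets, keyed per layer like the defaultdict above.
new_weights = {0: [0.1] * (8 * 4), 1: [0.2] * (2 * 8)}

baseline = nn.Sequential(nn.Linear(4, 8, bias=False), nn.Linear(8, 2, bias=False))

with torch.no_grad():
    for weights, parameters in zip(new_weights.values(), baseline.parameters()):
        # In-place copy: reshape the flat weights onto the parameter's shape.
        parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
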
@@ -56,7 +60,7 @@ if __name__ == '__main__':
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()
 
-    model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
+    model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(model, weights, d_test)
 
@@ -120,7 +120,7 @@ class SparseLayer(nn.Module):
 
 def test_sparse_layer():
     net = SparseLayer(500)  # 500 parallel nets
-    loss_fn = torch.nn.MSELoss(reduction="sum")
+    loss_fn = torch.nn.MSELoss()
     optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
     # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
 
@@ -138,9 +138,10 @@ def test_sparse_layer():
         loss.backward()
         optimizer.step()
 
-    epsilon = pow(10, -5)
-    # is each of the networks self-replicating?
-    print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
+    counter = defaultdict(lambda: 0)
+    id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
+    counter = dict(counter)
+    print(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}")
 
 
 def embed_batch(x, repeat_dim):
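
One small design note on the new check: a `defaultdict(lambda: 0)` lets `test_for_fixpoints` increment arbitrary fixpoint categories without pre-registering keys, and casting back to `dict` before printing drops the noisy `defaultdict(<function ...>, ...)` repr. A toy illustration of that counting idiom (the category names are placeholders, not the repo's actual taxonomy):

from collections import defaultdict

counter = defaultdict(lambda: 0)  # missing keys start at zero
for verdict in ('identity_func', 'identity_func', 'divergent'):
    counter[verdict] += 1
print(dict(counter))  # {'identity_func': 2, 'divergent': 1}
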
@@ -239,7 +240,7 @@ class SparseNetwork(nn.Module):
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)
 
-            losses.append(F.mse_loss(output, target_data))
+            losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
         return torch.hstack(losses).sum(dim=-1, keepdim=True)
 
     def replace_weights_by_particles(self, particles):
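
For intuition on the `/ layer.nr_nets` change: each layer contributes one scalar MSE term to the combined self-train objective, and layers hosting many parallel particle nets would otherwise dominate the sum. A toy sketch of the combination (the sizes and residual are made up):

import torch
import torch.nn.functional as F

losses = []
for nr_nets in (5, 50):  # two layers with very different particle counts
    output = torch.randn(nr_nets, 4)
    target = output + 0.1  # pretend self-train residual
    losses.append(F.mse_loss(output, target) / nr_nets)  # per-layer rescale
total = torch.hstack(losses).sum(dim=-1, keepdim=True)
print(total.shape)  # torch.Size([1])
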
@@ -269,21 +270,31 @@ def test_sparse_net():
 
 def test_sparse_net_sef_train():
     net = SparseNetwork(30, 5, 6, 10)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
-    optimizer_dict = {
-        key: torch.optim.SGD(layer.parameters(), lr=0.008, momentum=0.9) for key, layer in enumerate(net.sparselayers)
-    }
     epochs = 1000
-    loss_fn = torch.nn.MSELoss(reduction="sum")
-    for _ in trange(epochs):
-        for layer, optim in zip(net.sparselayers, optimizer_dict.values()):
-            optim.zero_grad()
-            x, target_data = layer.get_self_train_inputs_and_targets()
-            output = layer(x)
-            loss = loss_fn(output, target_data)
-            loss.backward()
-            optim.step()
+    if True:
+        optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
+        for _ in trange(epochs):
+            optimizer.zero_grad()
+            loss = net.combined_self_train()
+            print(loss)
+            exit()
+            loss.backward()
+            optimizer.step()
+
+    else:
+        optimizer_dict = {
+            key: torch.optim.SGD(layer.parameters(), lr=0.004, momentum=0.9) for key, layer in enumerate(net.sparselayers)
+        }
+        loss_fn = torch.nn.MSELoss(reduction="mean")
+
+        for layer, optim in zip(net.sparselayers, optimizer_dict.values()):
+            for _ in trange(epochs):
+                optim.zero_grad()
+                x, target_data = layer.get_self_train_inputs_and_targets()
+                output = layer(x)
+                loss = loss_fn(output, target_data)
+                loss.backward()
+                optim.step()
 
     # is each of the networks self-replicating?
     counter = defaultdict(lambda: 0)
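
The `if True:` branch above is the new sanity method itself: it prints the combined self-train loss and exits before any optimizer step, so the later backward/step lines are deliberately unreachable for now. The dormant `else:` branch keeps one SGD instance per sparse layer so each layer's self-train step only touches its own parameters. A minimal, self-contained sketch of that per-submodule optimizer pattern (plain `nn.Linear` layers stand in for the repo's `SparseLayer`):

import torch
import torch.nn as nn

sub_layers = [nn.Linear(3, 3) for _ in range(4)]  # stand-ins for sparse layers

# One optimizer per layer, mirroring optimizer_dict in the else-branch.
optimizer_dict = {
    key: torch.optim.SGD(layer.parameters(), lr=0.004, momentum=0.9)
    for key, layer in enumerate(sub_layers)
}

loss_fn = nn.MSELoss(reduction="mean")
for layer, optim in zip(sub_layers, optimizer_dict.values()):
    for _ in range(100):  # a few self-train iterations per layer
        optim.zero_grad()
        x = torch.randn(16, 3)
        loss = loss_fn(layer(x), x)  # identity target, as in self-replication
        loss.backward()
        optim.step()
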
@@ -313,7 +324,7 @@ def test_manual_for_loop():
 
 
 if __name__ == '__main__':
-    test_sparse_layer()
+    # test_sparse_layer()
     test_sparse_net_sef_train()
     # test_sparse_net()
     # for comparison