diff --git a/experiments/meta_task_sanity_exp.py b/experiments/meta_task_sanity_exp.py index 1e94b6b..dc152cf 100644 --- a/experiments/meta_task_sanity_exp.py +++ b/experiments/meta_task_sanity_exp.py @@ -62,13 +62,22 @@ if __name__ == '__main__': train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score']) dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000) - dataloader = DataLoader(dataset=dataset, batch_size=8000) + dataloader = DataLoader(dataset=dataset, batch_size=8000, num_workers=0) for epoch in trange(30): mean_batch_loss = [] mean_self_tain_loss = [] + for batch, (batch_x, batch_y) in tenumerate(dataloader): - self_train_loss, _ = net.self_train(2, save_history=False) - is_fixpoint = functionalities_test.is_zero_fixpoint(net) + # self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004) + for _ in range(2): + optimizer.zero_grad() + input_data = net.input_weight_matrix() + target_data = net.create_target_weights(input_data) + output = net(input_data) + self_train_loss = loss_fn(output, target_data) + self_train_loss.backward() + optimizer.step() + is_fixpoint = functionalities_test.is_identity_function(net) optimizer.zero_grad() batch_x_emb = torch.zeros(batch_x.shape[0], 5) @@ -80,7 +89,7 @@ if __name__ == '__main__': optimizer.step() if is_fixpoint: tqdm.write(f'is fixpoint after st : {is_fixpoint}') - tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}') + tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_identity_function(net)}') mean_batch_loss.append(loss.detach()) mean_self_tain_loss.append(self_train_loss.detach()) @@ -92,7 +101,7 @@ if __name__ == '__main__': counter = defaultdict(lambda: 0) functionalities_test.test_for_fixpoints(counter, nets=[net]) - print(dict(counter)) + print(dict(counter), self_train_loss) sanity = net(torch.Tensor([0,0,0,0,1])).detach() print(sanity) print(abs(sanity - multiplication_target))