smaller task train

Steffen Illium 2022-03-03 21:43:38 +01:00
parent 16c08d04d4
commit c52b398819


@@ -62,13 +62,22 @@ if __name__ == '__main__':
     train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
     dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
-    dataloader = DataLoader(dataset=dataset, batch_size=8000)
+    dataloader = DataLoader(dataset=dataset, batch_size=8000, num_workers=0)
     for epoch in trange(30):
         mean_batch_loss = []
         mean_self_tain_loss = []
         for batch, (batch_x, batch_y) in tenumerate(dataloader):
-            self_train_loss, _ = net.self_train(2, save_history=False)
-            is_fixpoint = functionalities_test.is_zero_fixpoint(net)
+            # self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
+            for _ in range(2):
+                optimizer.zero_grad()
+                input_data = net.input_weight_matrix()
+                target_data = net.create_target_weights(input_data)
+                output = net(input_data)
+                self_train_loss = loss_fn(output, target_data)
+                self_train_loss.backward()
+                optimizer.step()
+            is_fixpoint = functionalities_test.is_identity_function(net)
             optimizer.zero_grad()
             batch_x_emb = torch.zeros(batch_x.shape[0], 5)
@@ -80,7 +89,7 @@ if __name__ == '__main__':
             optimizer.step()
             if is_fixpoint:
                 tqdm.write(f'is fixpoint after st : {is_fixpoint}')
-                tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}')
+                tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_identity_function(net)}')
             mean_batch_loss.append(loss.detach())
             mean_self_tain_loss.append(self_train_loss.detach())
@@ -92,7 +101,7 @@ if __name__ == '__main__':
     counter = defaultdict(lambda: 0)
     functionalities_test.test_for_fixpoints(counter, nets=[net])
-    print(dict(counter))
+    print(dict(counter), self_train_loss)
     sanity = net(torch.Tensor([0,0,0,0,1])).detach()
     print(sanity)
     print(abs(sanity - multiplication_target))
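
For reference, a minimal sketch of the self-train step that this commit inlines in place of the commented-out net.self_train(2, ...) call. It assumes net exposes the input_weight_matrix() and create_target_weights() methods used in the diff and that optimizer and loss_fn (e.g. nn.MSELoss()) are set up as elsewhere in the script; the helper name self_train_step is illustrative and not part of the repository.

import torch

def self_train_step(net, optimizer, loss_fn, steps: int = 2) -> torch.Tensor:
    # Sketch of the inlined self-train loop: for `steps` gradient updates, the
    # network regresses its own encoded weights (input -> target taken from the
    # weight matrix itself), which is what is_identity_function(net) later checks.
    self_train_loss = torch.zeros(1)
    for _ in range(steps):
        optimizer.zero_grad()
        input_data = net.input_weight_matrix()                # own weights as input
        target_data = net.create_target_weights(input_data)   # own weights as target
        output = net(input_data)
        self_train_loss = loss_fn(output, target_data)
        self_train_loss.backward()
        optimizer.step()
    return self_train_loss.detach()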