smaller task train
This commit is contained in:
parent
16c08d04d4
commit
c52b398819
@ -62,13 +62,22 @@ if __name__ == '__main__':
|
||||
train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
|
||||
dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=8000)
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=8000, num_workers=0)
|
||||
for epoch in trange(30):
|
||||
mean_batch_loss = []
|
||||
mean_self_tain_loss = []
|
||||
|
||||
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
||||
self_train_loss, _ = net.self_train(2, save_history=False)
|
||||
is_fixpoint = functionalities_test.is_zero_fixpoint(net)
|
||||
# self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
|
||||
for _ in range(2):
|
||||
optimizer.zero_grad()
|
||||
input_data = net.input_weight_matrix()
|
||||
target_data = net.create_target_weights(input_data)
|
||||
output = net(input_data)
|
||||
self_train_loss = loss_fn(output, target_data)
|
||||
self_train_loss.backward()
|
||||
optimizer.step()
|
||||
is_fixpoint = functionalities_test.is_identity_function(net)
|
||||
|
||||
optimizer.zero_grad()
|
||||
batch_x_emb = torch.zeros(batch_x.shape[0], 5)
|
||||
@ -80,7 +89,7 @@ if __name__ == '__main__':
|
||||
optimizer.step()
|
||||
if is_fixpoint:
|
||||
tqdm.write(f'is fixpoint after st : {is_fixpoint}')
|
||||
tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}')
|
||||
tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_identity_function(net)}')
|
||||
|
||||
mean_batch_loss.append(loss.detach())
|
||||
mean_self_tain_loss.append(self_train_loss.detach())
|
||||
@ -92,7 +101,7 @@ if __name__ == '__main__':
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
functionalities_test.test_for_fixpoints(counter, nets=[net])
|
||||
print(dict(counter))
|
||||
print(dict(counter), self_train_loss)
|
||||
sanity = net(torch.Tensor([0,0,0,0,1])).detach()
|
||||
print(sanity)
|
||||
print(abs(sanity - multiplication_target))
|
||||
|
Loading…
x
Reference in New Issue
Block a user