smaller task train
This commit is contained in:
@ -295,8 +295,7 @@ def flat_for_store(parameters):
|
||||
|
||||
|
||||
def train_self_replication(model, optimizer, st_stps) -> dict:
|
||||
for _ in range(st_stps):
|
||||
self_train_loss = model.combined_self_train(optimizer)
|
||||
self_train_loss = model.combined_self_train(optimizer, st_stps)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
return stp_log
|
||||
|
@ -68,7 +68,8 @@ if __name__ == '__main__':
|
||||
mean_self_tain_loss = []
|
||||
|
||||
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
||||
# self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
|
||||
self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
|
||||
|
||||
for _ in range(2):
|
||||
optimizer.zero_grad()
|
||||
input_data = net.input_weight_matrix()
|
||||
|
Reference in New Issue
Block a user