new sanity methode
This commit is contained in:
@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 500 if not debug else 50
|
||||
EPOCH = 100
|
||||
EPOCH = 50
|
||||
VALIDATION_FRQ = 3 if not debug else 1
|
||||
SELF_TRAIN_FRQ = 1 if not debug else 1
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@@ -279,9 +279,9 @@ if __name__ == '__main__':
|
||||
|
||||
self_train = True
|
||||
training = True
|
||||
train_to_id_first = True
|
||||
train_to_id_first = False
|
||||
train_to_task_first = False
|
||||
train_to_task_first_sequential = False
|
||||
train_to_task_first_sequential = True
|
||||
force_st_for_n_from_last_epochs = 5
|
||||
|
||||
use_sparse_network = False
|
||||
@@ -303,10 +303,12 @@ if __name__ == '__main__':
|
||||
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
|
||||
id_str = f'{f"_StToId" if train_to_id_first else ""}'
|
||||
tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
|
||||
sprs_str = '_sprs' if use_sparse_network else ''
|
||||
f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
|
||||
force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
|
||||
else ""
|
||||
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}{f_str}'
|
||||
config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
|
||||
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
|
||||
|
||||
for seed in range(n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
@@ -358,8 +360,8 @@ if __name__ == '__main__':
|
||||
force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
|
||||
) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
|
||||
# Self Train
|
||||
|
||||
# Self Train
|
||||
if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
|
||||
# Transfer weights
|
||||
if use_sparse_network:
|
||||
@@ -376,6 +378,8 @@ if __name__ == '__main__':
|
||||
# Transfer weights
|
||||
if use_sparse_network:
|
||||
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
|
||||
|
||||
# Task Train
|
||||
if not init_st:
|
||||
# Zero your gradients for every batch!
|
||||
dense_optimizer.zero_grad()
|
||||
|
Reference in New Issue
Block a user