new sanity methode

This commit is contained in:
Steffen Illium
2022-02-23 18:23:00 +01:00
parent ebf133414c
commit 3da00c793b
4 changed files with 57 additions and 32 deletions

View File

@@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints
WORKER = 10 if not debug else 2
debug = False
BATCHSIZE = 500 if not debug else 50
EPOCH = 100
EPOCH = 50
VALIDATION_FRQ = 3 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -279,9 +279,9 @@ if __name__ == '__main__':
self_train = True
training = True
train_to_id_first = True
train_to_id_first = False
train_to_task_first = False
train_to_task_first_sequential = False
train_to_task_first_sequential = True
force_st_for_n_from_last_epochs = 5
use_sparse_network = False
@@ -303,10 +303,12 @@ if __name__ == '__main__':
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
id_str = f'{f"_StToId" if train_to_id_first else ""}'
tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
sprs_str = '_sprs' if use_sparse_network else ''
f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
else ""
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}{f_str}'
config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
for seed in range(n_seeds):
seed_path = exp_path / str(seed)
@@ -358,8 +360,8 @@ if __name__ == '__main__':
force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
# Self Train
# Self Train
if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
# Transfer weights
if use_sparse_network:
@@ -376,6 +378,8 @@ if __name__ == '__main__':
# Transfer weights
if use_sparse_network:
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
# Task Train
if not init_st:
# Zero your gradients for every batch!
dense_optimizer.zero_grad()