sparse net training

This commit is contained in:
Steffen Illium
2022-02-26 16:01:12 +01:00
parent 9d8496a725
commit c0db8e19a3
3 changed files with 116 additions and 92 deletions

View File

@@ -287,17 +287,17 @@ def flat_for_store(parameters):
if __name__ == '__main__':
self_train = True
training = False
train_to_id_first = True
training = True
train_to_id_first = False
train_to_task_first = False
sequential_task_train = True
force_st_for_n_from_last_epochs = 5
n_st_per_batch = 3
activation = None # nn.ReLU()
use_sparse_network = True
use_sparse_network = False
for weight_hidden_size in [3, 4, 5, 6]:
for weight_hidden_size in [8]:
tsk_threshold = 0.85
weight_hidden_size = weight_hidden_size
@@ -353,15 +353,16 @@ if __name__ == '__main__':
meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
loss_fn = nn.CrossEntropyLoss()
dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
sparse_optimizer = torch.optim.SGD(
sparse_metanet.parameters(), lr=0.008, momentum=0.9
sparse_metanet.parameters(), lr=0.004, momentum=0.9
) if use_sparse_network else dense_optimizer
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', meta_weight_count)
init_tsk = train_to_task_first
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
tqdm.write(f'{seed}: {exp_path}')
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
sparse_metanet = sparse_metanet.train()