sparse net training
This commit is contained in:
@@ -287,17 +287,17 @@ def flat_for_store(parameters):
 if __name__ == '__main__':
 
 self_train = True
-training = False
-train_to_id_first = True
+training = True
+train_to_id_first = False
 train_to_task_first = False
 sequential_task_train = True
 force_st_for_n_from_last_epochs = 5
 n_st_per_batch = 3
 activation = None # nn.ReLU()
 
-use_sparse_network = True
+use_sparse_network = False
 
-for weight_hidden_size in [3, 4, 5, 6]:
+for weight_hidden_size in [8]:
 
 tsk_threshold = 0.85
 weight_hidden_size = weight_hidden_size
||||
@@ -353,15 +353,16 @@ if __name__ == '__main__':
 meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
 
 loss_fn = nn.CrossEntropyLoss()
-dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
+dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
 sparse_optimizer = torch.optim.SGD(
-sparse_metanet.parameters(), lr=0.008, momentum=0.9
+sparse_metanet.parameters(), lr=0.004, momentum=0.9
 ) if use_sparse_network else dense_optimizer
 
 train_store = new_storage_df('train', None)
 weight_store = new_storage_df('weights', meta_weight_count)
 init_tsk = train_to_task_first
-for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
+for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
+tqdm.write(f'{seed}: {exp_path}')
 is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
 is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
 sparse_metanet = sparse_metanet.train()
Reference in New Issue
Block a user