parameters for training
This commit is contained in:
parent
d4c25872c6
commit
594bbaa3dd
@ -41,7 +41,7 @@ from functionalities_test import test_for_fixpoints, FixTypes
|
||||
|
||||
WORKER = 10 if not debug else 2
|
||||
BATCHSIZE = 500 if not debug else 50
|
||||
EPOCH = 400 if not debug else 3
|
||||
EPOCH = 200
|
||||
VALIDATION_FRQ = 5 if not debug else 1
|
||||
SELF_TRAIN_FRQ = 1 if not debug else 1
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@ -194,15 +194,16 @@ def flat_for_store(parameters):
|
||||
if __name__ == '__main__':
|
||||
|
||||
self_train = True
|
||||
training = False
|
||||
training = True
|
||||
plotting = True
|
||||
particle_analysis = True
|
||||
as_sparse_network_test = True
|
||||
self_train_alpha = 100
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
run_path = Path('output') / 'mn_st_NoRes'
|
||||
run_path = Path('output') / 'mn_st_200_8_alpha_100'
|
||||
model_path = run_path / '0000_trained_model.zip'
|
||||
df_store_path = run_path / 'train_store.csv'
|
||||
weight_store_path = run_path / 'weight_store.csv'
|
||||
@ -216,7 +217,7 @@ if __name__ == '__main__':
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
|
||||
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
|
||||
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
@ -236,7 +237,7 @@ if __name__ == '__main__':
|
||||
if self_train and is_self_train_epoch:
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
self_train_loss = metanet.combined_self_train()
|
||||
self_train_loss = metanet.combined_self_train() * self_train_alpha
|
||||
self_train_loss.backward()
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
|
@ -296,7 +296,7 @@ class MetaCell(nn.Module):
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.weight_interface = 5
|
||||
self.net_hidden_size = 4
|
||||
self.net_hidden_size = 8
|
||||
self.net_ouput_size = 1
|
||||
self.meta_weight_list = nn.ModuleList()
|
||||
self.meta_weight_list.extend(
|
||||
|
Loading…
x
Reference in New Issue
Block a user