parameters for training

This commit is contained in:
Steffen Illium 2022-02-08 17:24:00 +01:00
parent d4c25872c6
commit 594bbaa3dd
2 changed files with 7 additions and 6 deletions

View File

@ -41,7 +41,7 @@ from functionalities_test import test_for_fixpoints, FixTypes
WORKER = 10 if not debug else 2
BATCHSIZE = 500 if not debug else 50
EPOCH = 400 if not debug else 3
EPOCH = 200
VALIDATION_FRQ = 5 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -194,15 +194,16 @@ def flat_for_store(parameters):
if __name__ == '__main__':
self_train = True
training = False
training = True
plotting = True
particle_analysis = True
as_sparse_network_test = True
self_train_alpha = 100
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'mn_st_NoRes'
run_path = Path('output') / 'mn_st_200_8_alpha_100'
model_path = run_path / '0000_trained_model.zip'
df_store_path = run_path / 'train_store.csv'
weight_store_path = run_path / 'weight_store.csv'
@ -216,7 +217,7 @@ if __name__ == '__main__':
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
interface = np.prod(dataset[0][0].shape)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
loss_fn = nn.CrossEntropyLoss()
@ -236,7 +237,7 @@ if __name__ == '__main__':
if self_train and is_self_train_epoch:
# Zero your gradients for every batch!
optimizer.zero_grad()
self_train_loss = metanet.combined_self_train()
self_train_loss = metanet.combined_self_train() * self_train_alpha
self_train_loss.backward()
# Adjust learning weights
optimizer.step()

View File

@ -296,7 +296,7 @@ class MetaCell(nn.Module):
self.name = name
self.interface = interface
self.weight_interface = 5
self.net_hidden_size = 4
self.net_hidden_size = 8
self.net_ouput_size = 1
self.meta_weight_list = nn.ModuleList()
self.meta_weight_list.extend(