README.md Update

This commit is contained in:
Steffen Illium
2022-02-10 16:53:49 +01:00
parent 594bbaa3dd
commit 14768ffc0a
8 changed files with 134 additions and 18 deletions

View File

@ -198,12 +198,13 @@ if __name__ == '__main__':
plotting = True
particle_analysis = True
as_sparse_network_test = True
self_train_alpha = 100
self_train_alpha = 1
batch_train_beta = 1
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'mn_st_200_8_alpha_100'
run_path = Path('output') / 'mn_st_400_2_no_res'
model_path = run_path / '0000_trained_model.zip'
df_store_path = run_path / 'train_store.csv'
weight_store_path = run_path / 'weight_store.csv'
@ -217,7 +218,7 @@ if __name__ == '__main__':
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
interface = np.prod(dataset[0][0].shape)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
loss_fn = nn.CrossEntropyLoss()
@ -249,7 +250,7 @@ if __name__ == '__main__':
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y = metanet(batch_x)
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
loss = loss_fn(y, batch_y.to(torch.long))
loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
loss.backward()
# Adjust learning weights
@ -312,7 +313,7 @@ if __name__ == '__main__':
plot_training_result(df_store_path)
if particle_analysis:
plot_training_particle_types(df_store_path)
exit()
if particle_analysis:
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()