README.md Update
This commit is contained in:
@ -198,12 +198,13 @@ if __name__ == '__main__':
|
||||
plotting = True
|
||||
particle_analysis = True
|
||||
as_sparse_network_test = True
|
||||
self_train_alpha = 100
|
||||
self_train_alpha = 1
|
||||
batch_train_beta = 1
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
run_path = Path('output') / 'mn_st_200_8_alpha_100'
|
||||
run_path = Path('output') / 'mn_st_400_2_no_res'
|
||||
model_path = run_path / '0000_trained_model.zip'
|
||||
df_store_path = run_path / 'train_store.csv'
|
||||
weight_store_path = run_path / 'weight_store.csv'
|
||||
@ -217,7 +218,7 @@ if __name__ == '__main__':
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
|
||||
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
|
||||
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
@ -249,7 +250,7 @@ if __name__ == '__main__':
|
||||
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
|
||||
y = metanet(batch_x)
|
||||
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
|
||||
loss = loss_fn(y, batch_y.to(torch.long))
|
||||
loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
|
||||
loss.backward()
|
||||
|
||||
# Adjust learning weights
|
||||
@ -312,7 +313,7 @@ if __name__ == '__main__':
|
||||
plot_training_result(df_store_path)
|
||||
if particle_analysis:
|
||||
plot_training_particle_types(df_store_path)
|
||||
exit()
|
||||
|
||||
if particle_analysis:
|
||||
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
|
Reference in New Issue
Block a user