Initial Push

This commit is contained in:
Steffen Illium
2022-02-02 12:03:31 +01:00
parent 1b7581e656
commit eb3b9b8958
3 changed files with 83 additions and 63 deletions

View File

@ -165,7 +165,7 @@ if __name__ == '__main__':
self_train = False
training = False
plotting = False
plotting = True
particle_analysis = True
as_sparse_network_test = True
@ -185,22 +185,28 @@ if __name__ == '__main__':
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
interface = np.prod(dataset[0][0].shape)
metanet = MetaNet(interface, depth=4, width=6, out=10).to(DEVICE).train()
metanet = MetaNet(interface, depth=5, width=6, out=10).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
train_store = new_train_storage_df()
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
metanet = metanet.train()
if is_validation_epoch:
metric = torchmetrics.Accuracy()
else:
metric = None
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if self_train and is_self_train_epoch:
self_train_loss = metanet.combined_self_train(optimizer)
# Zero your gradients for every batch!
optimizer.zero_grad()
self_train_loss = metanet.combined_self_train()
self_train_loss.backward()
# Adjust learning weights
optimizer.step()
step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
train_store.loc[train_store.shape[0]] = step_log
@ -225,6 +231,7 @@ if __name__ == '__main__':
break
if is_validation_epoch:
metanet = metanet.eval()
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Train Accuracy', Score=metric.compute().item())
train_store.loc[train_store.shape[0]] = validation_log
@ -241,8 +248,9 @@ if __name__ == '__main__':
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
train_store.loc[train_store.shape[0]] = step_log
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
train_store = new_train_storage_df()
# train_store = new_train_storage_df()
metanet.eval()
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
@ -254,7 +262,7 @@ if __name__ == '__main__':
plot_training_result(df_store_path)
if particle_analysis:
model_path = next(run_path.glob('*ckpt.tp'))
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()
counter_dict = defaultdict(lambda: 0)
_ = test_for_fixpoints(counter_dict, list(latest_model.particles))