Initial Push
This commit is contained in:
@ -165,7 +165,7 @@ if __name__ == '__main__':
|
||||
|
||||
self_train = False
|
||||
training = False
|
||||
plotting = False
|
||||
plotting = True
|
||||
particle_analysis = True
|
||||
as_sparse_network_test = True
|
||||
|
||||
@ -185,22 +185,28 @@ if __name__ == '__main__':
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=4, width=6, out=10).to(DEVICE).train()
|
||||
metanet = MetaNet(interface, depth=5, width=6, out=10).to(DEVICE)
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.008, momentum=0.9)
|
||||
|
||||
train_store = new_train_storage_df()
|
||||
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
|
||||
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
|
||||
metanet = metanet.train()
|
||||
if is_validation_epoch:
|
||||
metric = torchmetrics.Accuracy()
|
||||
else:
|
||||
metric = None
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
|
||||
if self_train and is_self_train_epoch:
|
||||
self_train_loss = metanet.combined_self_train(optimizer)
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
self_train_loss = metanet.combined_self_train()
|
||||
self_train_loss.backward()
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
|
||||
@ -225,6 +231,7 @@ if __name__ == '__main__':
|
||||
break
|
||||
|
||||
if is_validation_epoch:
|
||||
metanet = metanet.eval()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric='Train Accuracy', Score=metric.compute().item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
@ -241,8 +248,9 @@ if __name__ == '__main__':
|
||||
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
|
||||
train_store = new_train_storage_df()
|
||||
# train_store = new_train_storage_df()
|
||||
|
||||
metanet.eval()
|
||||
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric='Test Accuracy', Score=accuracy.item())
|
||||
@ -254,7 +262,7 @@ if __name__ == '__main__':
|
||||
plot_training_result(df_store_path)
|
||||
|
||||
if particle_analysis:
|
||||
model_path = next(run_path.glob('*ckpt.tp'))
|
||||
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
_ = test_for_fixpoints(counter_dict, list(latest_model.particles))
|
||||
|
Reference in New Issue
Block a user