refactoring and running experiments

Steffen Illium
2022-03-05 16:51:19 +01:00
parent 69c904e156
commit f3ff4c9239
5 changed files with 196 additions and 203 deletions


@@ -1,12 +1,5 @@
from pathlib import Path
import torch
import torchmetrics
from torch.utils.data import Dataset
from tqdm import tqdm
from experiments.meta_task_utility import set_checkpoint
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -24,31 +17,6 @@ class AddTaskDataset(Dataset):
        return ab, ab.sum(axis=-1, keepdims=True)
def validate(checkpoint_path, valid_d, ratio=1, validmetric=None):
    checkpoint_path = Path(checkpoint_path)
    # Initialize the metric; building it here rather than as a default
    # argument keeps accumulated state from leaking between calls.
    validmetric = validmetric or torchmetrics.MeanAbsoluteError()
    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
    n_samples = int(len(valid_d) * ratio)
    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
            if idx >= n_samples:  # stop once the requested ratio of batches is covered
                break
            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
            y_valid = model(valid_batch_x)
            # metric on current batch
            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()
    # metric on all batches using torchmetrics' built-in accumulation
    acc = validmetric.compute()
    tqdm.write(f"Avg. Accuracy on all data: {acc}")
    return acc
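# Note: torchmetrics metrics are stateful. Each call such as
# validmetric(preds, target) returns the value for that batch *and*
# updates internal running state, which is why validmetric.compute()
# above yields the aggregate over every batch seen so far;
# validmetric.reset() would clear that state for a fresh run.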
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
    # Zero your gradients for every batch!
    optimizer.zero_grad()
@@ -66,12 +34,5 @@ def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tens
    return stp_log, y_prd
def checkpoint_and_validate(model, out_path, epoch_n, valid_d, final_model=False):
    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
    result = validate(ckpt_path, valid_d)
    return result
if __name__ == '__main__':
    raise NotImplementedError('Get out of here')
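For context, a minimal sketch of how these helpers might be wired together. The network, optimizer, batch size, epoch count, and output path below are illustrative assumptions, not part of this commit (the diff also hides AddTaskDataset's constructor and most of train_task's body):

from torch.utils.data import DataLoader

# Hypothetical setup -- assumes AddTaskDataset() takes no required arguments.
train_d = DataLoader(AddTaskDataset(), batch_size=50)
valid_d = DataLoader(AddTaskDataset(), batch_size=50)

model = torch.nn.Linear(2, 1).to(DEVICE)  # stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=4e-3)
loss_func = torch.nn.MSELoss()

EPOCHS = 10
for epoch_n in range(EPOCHS):
    for btch_x, btch_y in train_d:
        btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
        stp_log, y_prd = train_task(model, optimizer, loss_func, btch_x, btch_y)
    # Checkpoint every epoch and validate the saved model on held-out data.
    checkpoint_and_validate(model, 'output', epoch_n, valid_d,
                            final_model=(epoch_n == EPOCHS - 1))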