from pathlib import Path
import pickle

import pandas as pd
import torch
import torch.nn as nn
from torch.nn import Flatten
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
import torchmetrics
from tqdm import tqdm

from network import MetaNetCompareBaseline

WORKER = 0
BATCHSIZE = 500
EPOCH = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

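# Both datasets are reduced to the same 100-dimensional input: resized to 10x10,
# normalized, and flattened. CIFAR10 is additionally converted to grayscale so that
# the same network architecture can be trained on either dataset.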
MNIST_TRANSFORM = Compose([Resize((10, 10)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
CIFAR10_TRANSFORM = Compose([Grayscale(num_output_channels=1), Resize((10, 10)), ToTensor(), Normalize((0.48,), (0.25,)), Flatten(start_dim=0)])


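# Train `testnet` for EPOCH epochs on `trainset`, then evaluate it on `testset`
# with torchmetrics and return the accuracy aggregated over all test batches.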
def train_and_test(testnet, optimizer, loss_fn, trainset, testset):
    d_train = DataLoader(trainset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
    d_test = DataLoader(testset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)

    # The caller builds the network on the CPU, so move it to the training device here.
    testnet = testnet.to(DEVICE)

    # train
    testnet.train()
    for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epoch'):
        for batch, (batch_x, batch_y) in enumerate(d_train):
            optimizer.zero_grad()
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            y = testnet(batch_x)
            loss = loss_fn(y, batch_y)
            loss.backward()
            optimizer.step()

    # test
    testnet.eval()
    # Note: torchmetrics >= 0.11 requires Accuracy(task='multiclass', num_classes=10).
    metric = torchmetrics.Accuracy()
    with torch.no_grad():
        with tqdm(total=len(d_test), desc='MetaNet Test - Batch') as pbar:
            for batch, (batch_x, batch_y) in enumerate(d_test):
                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                y = testnet(batch_x)
                acc = metric(y.cpu(), batch_y.cpu())
                pbar.set_postfix_str(f'Acc: {acc}')
                pbar.update()

    acc = metric.compute()
    tqdm.write(f"Avg. accuracy on all data: {acc}")
    return acc


if __name__ == '__main__':
    torch.manual_seed(42)
    data_path = Path('data')
    data_path.mkdir(exist_ok=True, parents=True)

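    # Datasets are downloaded into ./data on first use and share the 10x10 transforms defined above.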
    mnist_train = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=True)
    mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
    cifar10_train = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=True)
    cifar10_test = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=False)

    loss_fn = nn.CrossEntropyLoss()
    frame = pd.DataFrame(columns=['Dataset', 'Neurons', 'Layers', 'Parameters', 'Accuracy'])

    for name, trainset, testset in [("MNIST", mnist_train, mnist_test), ("CIFAR10", cifar10_train, cifar10_test)]:
        best_acc = 0
        neuron_count = 0
        layer_count = 0

        # Search for the smallest architecture that reaches 95% accuracy: widen in steps
        # of 10 neurons per layer; past 200 neurons, reset the width to 10 and add a hidden layer.
        while best_acc <= 0.95:
            neuron_count += 10
            if neuron_count >= 210:
                neuron_count = 10
                layer_count += 1
            # 100 input features (flattened 10x10 image), 10 output classes
            net = MetaNetCompareBaseline(100, layer_count, neuron_count, out=10)
            optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
            acc = train_and_test(net, optimizer, loss_fn, trainset, testset)
            if acc > best_acc:
                best_acc = acc

            # Record every trial of the search, then report the full table at the end.
            num_params = sum(p.numel() for p in net._meta_layer_list.parameters())
            frame.loc[frame.shape[0]] = dict(Dataset=name, Neurons=neuron_count, Layers=layer_count, Parameters=num_params, Accuracy=acc)
            print(f"> {name}\t| {neuron_count} neurons\t| {layer_count} h.-layer(s)\t| {num_params} params\n")

    print(frame)
    with open("min_net_search_df.pkl", 'wb') as f:
        pickle.dump(frame, f)