Add residual skip connections to MetaNetCompareBaseline

This commit is contained in:
Steffen Illium 2022-02-23 18:32:36 +01:00
parent 3da00c793b
commit 5b2b5b5beb
2 changed files with 20 additions and 13 deletions

View File

@@ -487,25 +487,33 @@ class MetaNet(nn.Module):
class MetaNetCompareBaseline(nn.Module): class MetaNetCompareBaseline(nn.Module):
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None): def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
super().__init__() super().__init__()
self.residual_skip = residual_skip
self.activation = activation self.activation = activation
self.out = out self.out = out
self.interface = interface self.interface = interface
self.width = width self.width = width
self.depth = depth self.depth = depth
self._meta_layer_list = nn.ModuleList() self._first_layer = nn.Linear(self.interface, self.width, bias=False)
self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False)) self._last_layer = nn.Linear(self.width, self.out, bias=False)
self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False))
def forward(self, x): def forward(self, x):
tensor = x tensor = self._first_layer(x)
for meta_layer in self._meta_layer_list: for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
if idx % 2 == 1 and self.residual_skip:
x = tensor.clone()
tensor = meta_layer(tensor) tensor = meta_layer(tensor)
if idx % 2 == 0 and self.residual_skip:
tensor = tensor + x
tensor = self._last_layer(tensor)
return tensor return tensor
@property
def all_layers(self):
return (x for x in (self._first_layer, *self._meta_layer_list, self._last_layer))
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -26,7 +26,8 @@ def extract_weights_from_model(model:MetaNet)->dict:
def test_weights_as_model(model, new_weights:dict, data): def test_weights_as_model(model, new_weights:dict, data):
TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out) TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
residual_skip=True)
with torch.no_grad(): with torch.no_grad():
for weights, parameters in zip(new_weights.values(), TransferNet.parameters()): for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
@@ -37,7 +38,6 @@ def test_weights_as_model(model, new_weights:dict, data):
with tqdm(desc='Test Batch: ') as pbar: with tqdm(desc='Test Batch: ') as pbar:
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'): for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
y = TransferNet(batch_x) y = TransferNet(batch_x)
loss = loss_fn(y, batch_y)
acc = metric(y.cpu(), batch_y.cpu()) acc = metric(y.cpu(), batch_y.cpu())
pbar.set_postfix_str(f'Acc: {acc}') pbar.set_postfix_str(f'Acc: {acc}')
pbar.update() pbar.update()
@@ -52,13 +52,12 @@ if __name__ == '__main__':
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
WORKER = 0 WORKER = 0
BATCHSIZE = 500 BATCHSIZE = 500
MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)]) MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
torch.manual_seed(42) torch.manual_seed(42)
data_path = Path('data') data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True) data_path.mkdir(exist_ok=True, parents=True)
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False) mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER) d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
loss_fn = nn.CrossEntropyLoss()
model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval() model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
weights = extract_weights_from_model(model) weights = extract_weights_from_model(model)