residual skip in metacomparebaseline
parent 3da00c793b
commit 5b2b5b5beb

network.py (24 changes)
@@ -487,26 +487,34 @@ class MetaNet(nn.Module):
 
 
 class MetaNetCompareBaseline(nn.Module):
 
-    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
+    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
         super().__init__()
+        self.residual_skip = residual_skip
         self.activation = activation
         self.out = out
         self.interface = interface
         self.width = width
         self.depth = depth
 
-        self._meta_layer_list = nn.ModuleList()
-
-        self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False))
-        self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
-        self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False))
+        self._first_layer = nn.Linear(self.interface, self.width, bias=False)
+        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._last_layer = nn.Linear(self.width, self.out, bias=False)
 
     def forward(self, x):
-        tensor = x
-        for meta_layer in self._meta_layer_list:
+        tensor = self._first_layer(x)
+        for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
+            if idx % 2 == 1 and self.residual_skip:
+                x = tensor.clone()
             tensor = meta_layer(tensor)
+            if idx % 2 == 0 and self.residual_skip:
+                tensor = tensor + x
+        tensor = self._last_layer(tensor)
         return tensor
 
     @property
     def all_layers(self):
-        return (x for x in self._meta_layer_list)
+        return (x for x in (self._first_layer, *self._meta_layer_list, self._last_layer))
 
 
 if __name__ == '__main__':
     metanet = MetaNet(interface=3, depth=5, width=3, out=1, residual_skip=True)
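For reference, here is the skip pattern the new forward implements, as a self-contained sketch: with residual_skip=True the hidden activation is snapshotted before each odd-indexed hidden layer and added back after the following even-indexed one, so every pair of hidden layers forms one residual block. The module below mirrors the committed MetaNetCompareBaseline; the class name, sizes, and random input are illustrative only.

import torch
import torch.nn as nn

class ResidualPairMLP(nn.Module):
    # Toy replica of MetaNetCompareBaseline's skip logic (illustrative only).
    def __init__(self, interface=4, depth=5, width=8, out=1, residual_skip=True):
        super().__init__()
        self.residual_skip = residual_skip
        self._first_layer = nn.Linear(interface, width, bias=False)
        self._meta_layer_list = nn.ModuleList(
            [nn.Linear(width, width, bias=False) for _ in range(depth - 2)])
        self._last_layer = nn.Linear(width, out, bias=False)

    def forward(self, x):
        tensor = self._first_layer(x)
        for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
            if idx % 2 == 1 and self.residual_skip:
                x = tensor.clone()   # snapshot before the residual block
            tensor = meta_layer(tensor)
            if idx % 2 == 0 and self.residual_skip:
                tensor = tensor + x  # add the snapshot back after two layers
        return self._last_layer(tensor)

print(ResidualPairMLP()(torch.randn(3, 4)).shape)  # torch.Size([3, 1])

Note that with an odd number of hidden layers (depth - 2 odd) the final snapshot is never added back; the committed forward behaves the same way.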
@@ -26,7 +26,8 @@ def extract_weights_from_model(model:MetaNet)->dict:
 
 
 def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
+    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
+                                         residual_skip=True)
 
     with torch.no_grad():
         for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
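The body of the copy loop sits outside this hunk; presumably each extracted weight tensor is written into the matching baseline parameter under no_grad. A minimal sketch of such a transfer, with hypothetical stand-ins for new_weights and TransferNet (the copy_ call is an assumption, not shown in the commit):

import torch
import torch.nn as nn

# Hypothetical stand-ins for the extracted weights and the target network.
new_weights = {'w0': torch.randn(8, 4), 'w1': torch.randn(1, 8)}
TransferNet = nn.Sequential(nn.Linear(4, 8, bias=False), nn.Linear(8, 1, bias=False))

with torch.no_grad():
    for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
        parameters.copy_(weights.view_as(parameters))  # assumed loop body

print(TransferNet(torch.randn(2, 4)).shape)  # torch.Size([2, 1])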
@@ -37,7 +38,6 @@ def test_weights_as_model(model, new_weights:dict, data):
     with tqdm(desc='Test Batch: ') as pbar:
         for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
             y = TransferNet(batch_x)
             loss = loss_fn(y, batch_y)
             acc = metric(y.cpu(), batch_y.cpu())
             pbar.set_postfix_str(f'Acc: {acc}')
             pbar.update()
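The metric callable is not defined in this hunk. Any function matching the call metric(y.cpu(), batch_y.cpu()) would work; a minimal hypothetical stand-in for classification logits is plain top-1 accuracy:

import torch

def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Top-1 accuracy; a hypothetical stand-in for the undefined `metric`.
    return (logits.argmax(dim=-1) == targets).float().mean().item()

print(accuracy(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 0])))  # 1.0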
@@ -52,13 +52,12 @@ if __name__ == '__main__':
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     WORKER = 0
     BATCHSIZE = 500
-    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
+    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
     torch.manual_seed(42)
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()
 
     model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)