adjustments for cuda and auto ckpt cleanup
This commit is contained in:
parent
ce5a36c8f4
commit
dd2458da4a
@ -70,7 +70,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
if isinstance(epoch_n, str):
|
||||
ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
|
||||
ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
@ -113,10 +113,20 @@ def new_storage_df(identifier, weight_count):
|
||||
return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
|
||||
|
||||
|
||||
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, final_model=False,
|
||||
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False,
|
||||
validation_metric=torchmetrics.Accuracy):
|
||||
out_path = Path(out_path)
|
||||
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
|
||||
# Clean up Checkpoints
|
||||
if keep_n > 0:
|
||||
all_ckpts = sorted(list(ckpt_path.parent.iterdir()))
|
||||
while len(all_ckpts) > keep_n:
|
||||
all_ckpts.pop(0).unlink()
|
||||
elif keep_n == 0:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}')
|
||||
|
||||
result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
|
||||
return result
|
||||
|
||||
|
@ -77,7 +77,7 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
# When this raises a Type Error, we found a second order fixpoint!
|
||||
steps += 1
|
||||
|
||||
df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
steps, absolute_loss,
|
||||
time_to_vergence[setting][noise_level],
|
||||
time_as_fixpoint[setting][noise_level]]
|
||||
|
@ -36,7 +36,7 @@ else:
|
||||
from network import MetaNet, FixTypes
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0), AddGaussianNoise()])
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) # , AddGaussianNoise()])
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 2000 if not debug else 50
|
||||
@ -60,16 +60,16 @@ plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = False
|
||||
plotting = False
|
||||
robustnes = True # EXPENSIV!!!!!!!
|
||||
n_st = 300 # per batch !!
|
||||
training = True
|
||||
plotting = True
|
||||
robustnes = True
|
||||
n_st = 1 # per batch !!
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
for weight_hidden_size in [3]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
residual_skip = True
|
||||
residual_skip = False
|
||||
n_seeds = 3
|
||||
depth = 5
|
||||
width = 3
|
||||
@ -84,7 +84,7 @@ if __name__ == '__main__':
|
||||
st_str = f'_nst_{n_st}'
|
||||
|
||||
config_str = f'{res_str}{ac_str}{st_str}'
|
||||
exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}_gauss'
|
||||
exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}'
|
||||
|
||||
for seed in range(n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
@ -161,7 +161,7 @@ if __name__ == '__main__':
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch).item()
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch, keep_n=5).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
@ -15,13 +15,12 @@ from functionalities_test import epsilon_error_margin as e
|
||||
from network import MetaNet, MetaNetCompareBaseline
|
||||
|
||||
|
||||
def extract_weights_from_model(model:MetaNet)->dict:
|
||||
inpt = torch.zeros(5)
|
||||
def extract_weights_from_model(model: MetaNet) -> dict:
|
||||
inpt = torch.zeros(5, device=next(model.parameters()).device, dtype=torch.float)
|
||||
inpt[-1] = 1
|
||||
inpt.long()
|
||||
|
||||
weights = defaultdict(list)
|
||||
layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
|
||||
layers = [layer.particles for layer in model.all_layers]
|
||||
for i, layer in enumerate(layers):
|
||||
for net in layer:
|
||||
weights[i].append(net(inpt).detach())
|
||||
@ -29,9 +28,10 @@ def extract_weights_from_model(model:MetaNet)->dict:
|
||||
|
||||
|
||||
def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
|
||||
meta_net_device = next(meta_net.parameters()).device
|
||||
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
|
||||
width=meta_net.width, out=meta_net.out,
|
||||
residual_skip=meta_net.residual_skip)
|
||||
residual_skip=meta_net.residual_skip).to(meta_net_device)
|
||||
with torch.no_grad():
|
||||
new_weight_values = list(new_weights.values())
|
||||
old_parameters = list(transfer_net.parameters())
|
||||
@ -45,7 +45,7 @@ def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics
|
||||
net.eval()
|
||||
metric = metric_class()
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
|
||||
y = net(batch_x)
|
||||
y = net(batch_x.to(meta_net_device))
|
||||
metric(y.cpu(), batch_y.cpu())
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
|
Loading…
x
Reference in New Issue
Block a user