Adjustments for CUDA and automatic checkpoint cleanup
This commit is contained in:
@@ -70,7 +70,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
if isinstance(epoch_n, str):
|
||||
ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
|
||||
ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
@@ -113,10 +113,20 @@ def new_storage_df(identifier, weight_count):
|
||||
return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
|
||||
|
||||
|
||||
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False,
                            validation_metric=torchmetrics.Accuracy):
    """Write a checkpoint, prune older checkpoints, and validate the new one.

    Args:
        model: Model handed to ``set_checkpoint`` for saving.
        valid_loader: Validation data handed to ``validate``.
        out_path: Output directory (``str`` or ``Path``) passed to
            ``set_checkpoint``.
        epoch_n: Epoch identifier passed to ``set_checkpoint``.
        keep_n: Maximum number of checkpoint files to keep in the checkpoint
            directory, counting the one just written. ``0`` disables pruning.
        final_model: Forwarded to ``set_checkpoint``. Pruning is skipped when
            this is true.
        validation_metric: Metric class forwarded to ``validate`` as
            ``metric_class``.

    Returns:
        The result of ``validate`` for the checkpoint that was just written.

    Raises:
        ValueError: If ``keep_n`` is negative.
    """
    # Reject bad arguments before anything is written to disk.
    if keep_n < 0:
        raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}')

    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)

    # Final checkpoints appear to be written directly into ``out_path`` (going by
    # the visible part of ``set_checkpoint``), so pruning their parent directory
    # would delete unrelated artefacts. Only prune for regular checkpoints.
    if keep_n > 0 and not final_model:
        # Only consider sibling files of the same type; never touch directories
        # and never the checkpoint that is about to be validated.
        older_ckpts = sorted(
            entry for entry in ckpt_path.parent.iterdir()
            if entry.is_file() and entry.suffix == ckpt_path.suffix and entry != ckpt_path
        )
        # The freshly written checkpoint occupies one of the ``keep_n`` slots.
        n_excess = len(older_ckpts) + 1 - keep_n
        for stale_ckpt in older_ckpts[:max(n_excess, 0)]:
            stale_ckpt.unlink()

    return validate(ckpt_path, valid_loader, metric_class=validation_metric)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
# When this raises a TypeError, we have found a second-order fixpoint!
|
||||
steps += 1
|
||||
|
||||
df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
steps, absolute_loss,
|
||||
time_to_vergence[setting][noise_level],
|
||||
time_as_fixpoint[setting][noise_level]]
|
||||
|
||||
Reference in New Issue
Block a user