diff --git a/utils/model_io.py b/utils/model_io.py index 936f07e..fc82c4a 100644 --- a/utils/model_io.py +++ b/utils/model_io.py @@ -1,5 +1,7 @@ from argparse import Namespace from collections import Mapping +from typing import Union + from copy import deepcopy from pathlib import Path @@ -66,16 +68,19 @@ class ModelParameters(Namespace, Mapping): class SavedLightningModels(object): @classmethod - def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''): + def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None): assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!' - found_checkpoints = list(Path(models_root_path).rglob('*.ckpt')) - - found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name) + if checkpoint is not None: + checkpoint_path = Path(checkpoint) + assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).' + else: + found_checkpoints = list(Path(models_root_path).rglob('*.ckpt')) + checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n] if model is None: model = torch.load(models_root_path / 'model_class.obj') assert model is not None - return cls(weights=found_checkpoints[n], model=model) + return cls(weights=checkpoint_path, model=model) def __init__(self, **kwargs): self.weights: str = kwargs.get('weights', '')