This commit is contained in:
Si11ium 2020-06-09 17:06:33 +02:00
parent d3fa32ae7b
commit ece80ecbed

View File

@ -1,5 +1,7 @@
from argparse import Namespace from argparse import Namespace
from collections import Mapping from collections import Mapping
from typing import Union
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
@ -66,16 +68,19 @@ class ModelParameters(Namespace, Mapping):
class SavedLightningModels(object): class SavedLightningModels(object):
@classmethod @classmethod
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''): def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!' assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt')) if checkpoint is not None:
checkpoint_path = Path(checkpoint)
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name) assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).'
else:
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
if model is None: if model is None:
model = torch.load(models_root_path / 'model_class.obj') model = torch.load(models_root_path / 'model_class.obj')
assert model is not None assert model is not None
return cls(weights=found_checkpoints[n], model=model) return cls(weights=checkpoint_path, model=model)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '') self.weights: str = kwargs.get('weights', '')