Model IO
This commit is contained in:
parent
d3fa32ae7b
commit
ece80ecbed
@ -1,5 +1,7 @@
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from collections import Mapping
|
from collections import Mapping
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -66,16 +68,19 @@ class ModelParameters(Namespace, Mapping):
|
|||||||
class SavedLightningModels(object):
|
class SavedLightningModels(object):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
|
def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
|
||||||
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
|
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
|
||||||
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
if checkpoint is not None:
|
||||||
|
checkpoint_path = Path(checkpoint)
|
||||||
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
|
assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).'
|
||||||
|
else:
|
||||||
|
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
||||||
|
checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
|
||||||
if model is None:
|
if model is None:
|
||||||
model = torch.load(models_root_path / 'model_class.obj')
|
model = torch.load(models_root_path / 'model_class.obj')
|
||||||
assert model is not None
|
assert model is not None
|
||||||
|
|
||||||
return cls(weights=found_checkpoints[n], model=model)
|
return cls(weights=checkpoint_path, model=model)
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.weights: str = kwargs.get('weights', '')
|
self.weights: str = kwargs.get('weights', '')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user