initial
lib/utils/model_io.py (new file, 76 lines)

@@ -0,0 +1,76 @@
from argparse import Namespace
from pathlib import Path

from natsort import natsorted
from torch import nn


# Hyperparameter Object
class ModelParameters(Namespace):

    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    @property
    def model_param(self):
        return self._model_param

    @property
    def train_param(self):
        return self._train_param

    @property
    def data_param(self):
        return self._data_param

    def __init__(self, model_param, train_param, data_param):
        self._model_param = model_param
        self._train_param = train_param
        self._data_param = data_param
        # Merge all three namespaces into one flat hyperparameter namespace.
        # Copy vars() so update() does not mutate the caller's model_param.
        kwargs = dict(vars(model_param))
        kwargs.update(vars(train_param))
        kwargs.update(vars(data_param))
        super(ModelParameters, self).__init__(**kwargs)

    def __getattribute__(self, item):
        if item == 'activation':
            # Resolve the stored activation name (e.g. 'leaky_relu') to its
            # nn module class; fall back to nn.ReLU if unset or unknown.
            try:
                return self._activations[super(ModelParameters, self).__getattribute__(item)]
            except (AttributeError, KeyError):
                return nn.ReLU
        return super(ModelParameters, self).__getattribute__(item)


class SavedLightningModels(object):

    @classmethod
    def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
        models_root_path = Path(models_root_path)
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
        found_checkpoints = list(models_root_path.rglob('*.ckpt'))

        # Natural sort by file name so index n=-1 picks the highest-numbered checkpoint.
        found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)

        if not tags_file_path:
            tag_files = models_root_path.rglob('meta_tags.csv')
            tags_file_path = list(tag_files)[0]

        return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)

    def __init__(self, **kwargs):
        self.weights: str = kwargs.get('weights', '')
        self.tags: str = kwargs.get('tags', '')

        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        # Rebuild the LightningModule from the checkpoint weights and the
        # meta_tags.csv file, then switch it to frozen evaluation mode.
        pretrained_model = self.model.load_from_metrics(
            weights_path=self.weights,
            tags_csv=self.tags
        )
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model
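
Usage sketch for ModelParameters (not part of the commit; the attribute names, values, and import path are assumptions based on the file location above): three argparse Namespaces are merged into one flat hyperparameter object, and the 'activation' entry is resolved to a torch.nn class via the _activations lookup.

from argparse import Namespace

from lib.utils.model_io import ModelParameters  # import path assumed from this commit

# Hypothetical hyperparameter groups -- attribute names are examples only.
model_args = Namespace(activation='leaky_relu', hidden_size=128)
train_args = Namespace(lr=1e-3, epochs=10)
data_args = Namespace(batch_size=32)

params = ModelParameters(model_args, train_args, data_args)
print(params.lr)          # 0.001, flattened in from train_args
print(params.activation)  # <class 'torch.nn...LeakyReLU'>, resolved from the stored name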
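A corresponding sketch for SavedLightningModels, assuming a LightningModule subclass MyLightningModule and a run directory containing *.ckpt files and a meta_tags.csv (all hypothetical names); restore() relies on the older pytorch-lightning load_from_metrics API used in the file above.

from pathlib import Path

from lib.utils.model_io import SavedLightningModels
from my_project.models import MyLightningModule  # hypothetical LightningModule class

loader = SavedLightningModels.load_checkpoint(
    models_root_path=Path('output/my_experiment'),  # hypothetical checkpoint directory
    model=MyLightningModule,                        # the class itself, not an instance
    n=-1,                                           # newest checkpoint after natural sorting
)
model = loader.restore()  # returned in eval() + freeze() mode, ready for inference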