61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
from argparse import Namespace
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from natsort import natsorted
|
|
from torch import nn
|
|
|
|
|
|
# Hyperparameter Object
|
|
class ModelParameters(Namespace):
    """Single namespace merging model, training and data parameters.

    The three source namespaces stay reachable as ``model_param``,
    ``train_param`` and ``data_param``; in addition, every attribute they
    hold is exposed directly on this object. On a name clash the later
    namespace wins (model < train < data).

    Reading ``activation`` is special-cased: the stored value is treated as
    a key into ``_activations`` and the matching ``torch.nn`` class is
    returned instead of the raw value.
    """

    # Activation name -> torch.nn module class.
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh,
    )

    def __init__(self, model_param, train_param, data_param):
        """Merge the three parameter namespaces into this one.

        Args:
            model_param: Namespace-like object holding model parameters.
            train_param: Namespace-like object holding training parameters.
            data_param: Namespace-like object holding data parameters.
        """
        self.model_param = model_param
        self.train_param = train_param
        self.data_param = data_param
        # vars() returns the live __dict__ of its argument; merge into a
        # copy so the caller's model_param is not polluted with the
        # training and data entries.
        kwargs = dict(vars(model_param))
        kwargs.update(vars(train_param))
        kwargs.update(vars(data_param))
        super(ModelParameters, self).__init__(**kwargs)

    def __getattribute__(self, item):
        """Return attribute ``item``, resolving ``activation`` to a class.

        For ``activation`` the configured value is looked up in
        ``_activations``. ``nn.ReLU`` is returned when no activation is
        configured or the configured value is not a known name.
        """
        if item == 'activation':
            try:
                # Look up the *configured value*, not the attribute name.
                configured = super(ModelParameters, self).__getattribute__(item)
                return self._activations[configured]
            except (AttributeError, KeyError, TypeError):
                # Missing attribute, unknown name or unhashable value.
                return nn.ReLU
        return super(ModelParameters, self).__getattribute__(item)
|
|
|
|
|
|
class SavedLightningModels(object):
    """Pairs a checkpoint file with the model class able to load it."""

    @classmethod
    def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
        """Build an instance from the ``*.ckpt`` files under a directory.

        Args:
            models_root_path: Directory (``str`` or ``Path``) that is searched
                recursively for ``*.ckpt`` files.
            model: Object providing ``load_from_checkpoint``. When ``None``,
                it is read from ``<models_root_path>/model_class.obj`` via
                ``torch.load``.
            n: Index into the checkpoints after natural sorting by file name;
                the default ``-1`` selects the last one.
            tags_file_path: Unused; kept for interface compatibility.

        Returns:
            A new instance holding the selected checkpoint path and the model.

        Raises:
            AssertionError: If the path does not exist or no model is available.
            IndexError: If no checkpoint is found or ``n`` is out of range.
        """
        # Normalise first so that plain string paths are accepted as well.
        models_root_path = Path(models_root_path)
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'

        found_checkpoints = natsorted(models_root_path.rglob('*.ckpt'), key=lambda y: y.name)
        if not found_checkpoints:
            # Same exception type as indexing the empty list, clearer message.
            raise IndexError(f'No "*.ckpt" files found under {models_root_path.absolute()}')

        if model is None:
            # NOTE(review): torch.load unpickles the file; only use this on
            # trusted model directories.
            model = torch.load(models_root_path / 'model_class.obj')
        assert model is not None

        return cls(weights=found_checkpoints[n], model=model)

    def __init__(self, **kwargs):
        """Store the checkpoint location and the model.

        Keyword Args:
            weights: Checkpoint location (``str`` or ``Path``); defaults to ``''``.
            model: Object providing ``load_from_checkpoint``; required.

        Raises:
            AssertionError: If ``model`` is missing or ``None``.
        """
        self.weights = kwargs.get('weights', '')

        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        """Load the checkpoint and return the model in eval mode, frozen.

        Calls ``self.model.load_from_checkpoint(self.weights)`` and then
        ``eval()`` and ``freeze()`` on the result before returning it.
        """
        pretrained_model = self.model.load_from_checkpoint(self.weights)
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model