# ml_lib/utils/model_io.py
# Snapshot dated 2020-05-16 08:18:26 +02:00 (originally 90 lines, 2.4 KiB, Python).

from argparse import Namespace
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path

import torch
from natsort import natsorted
from torch import nn
# Hyperparameter object
class ModelParameters(Namespace, Mapping):
    """Hyperparameter container usable as an ``argparse.Namespace`` and a read-only ``Mapping``.

    Parameters are stored as instance attributes. Item access (``params['x']``)
    returns the raw stored value. Attribute access adds two conveniences:

    * ``params.activation`` translates the stored activation *name*
      (e.g. ``'relu'``) into the matching ``torch.nn`` class from ``_activations``.
    * ``params.stretch`` evaluates to ``False`` when it was never set.
    """

    # Accepted values of the 'activation' entry -> torch.nn module classes.
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    def __init__(self, parameter_mapping):
        """Create the object from a mapping of parameter names to values."""
        super(ModelParameters, self).__init__(**parameter_mapping)

    @property
    def module_kwargs(self):
        """Return a deep copy of all parameters with 'activation' resolved to its ``nn`` class.

        The instance itself is not modified; only the returned dict holds the class.

        Raises:
            KeyError: If no 'activation' entry is set or its name is not in ``_activations``.
        """
        parameter_mapping = deepcopy(self.__dict__)
        parameter_mapping.update(
            dict(
                activation=self._activations[self['activation']]
            )
        )
        return parameter_mapping

    def __getitem__(self, k):
        # Mapping protocol: raw stored value (activation stays a string here).
        return self.__dict__[k]

    def __len__(self):
        return len(self.__dict__)

    def __iter__(self):
        # Iterate over a snapshot of the keys so attribute changes mid-iteration are safe.
        return iter(list(self.__dict__.keys()))

    def __delitem__(self, key):
        # Not part of the read-only Mapping ABC; deleting an item removes the attribute.
        self.__delattr__(key)
        return True

    def __getattribute__(self, name):
        """Resolve attributes, mapping 'activation' to an ``nn`` class and defaulting 'stretch'.

        Raises:
            AttributeError: If the attribute does not exist, including an unset
                'activation', so ``hasattr``/``getattr(..., default)`` behave normally.
            KeyError: If 'activation' is set to a name missing from ``_activations``.
        """
        if name == 'activation':
            try:
                activation_name = self['activation']
            except KeyError:
                # hasattr()/getattr(default) only tolerate AttributeError, not KeyError.
                raise AttributeError(
                    f"'{type(self).__name__}' object has no attribute 'activation'"
                ) from None
            return self._activations[activation_name]
        try:
            return super(ModelParameters, self).__getattribute__(name)
        except AttributeError:
            if name == 'stretch':
                # Optional flag: treat "never set" as disabled.
                return False
            # Re-raise the original error unchanged (keeps message and traceback).
            raise
class SavedLightningModels(object):
    """Pair a checkpoint file with the model class that can restore it.

    ``model`` is expected to expose ``load_from_checkpoint(path)`` returning an
    object with ``eval()`` and ``freeze()`` (presumably a PyTorch Lightning
    ``LightningModule`` subclass -- confirm against callers).
    """

    @classmethod
    def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
        """Build an instance from the ``n``-th checkpoint found below ``models_root_path``.

        Args:
            models_root_path: Directory (``str`` or ``Path``) searched recursively
                for ``*.ckpt`` files.
            model: Model class used to restore the weights. When ``None`` it is
                loaded with ``torch.load`` from ``models_root_path / 'model_class.obj'``.
            n: Index into the checkpoints after natural sorting by file name;
                the default ``-1`` selects the last one.
            tags_file_path: Not used by this method.

        Returns:
            A ``SavedLightningModels`` for the selected checkpoint.

        Raises:
            AssertionError: If ``models_root_path`` does not exist or no model
                class could be obtained.
            IndexError: If no checkpoint is found or ``n`` is out of range.
        """
        # Normalise first so plain string paths support .exists(), .rglob() and "/".
        models_root_path = Path(models_root_path)
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'

        found_checkpoints = list(models_root_path.rglob('*.ckpt'))
        if not found_checkpoints:
            # Fail fast with a clear message instead of a bare IndexError later.
            raise IndexError(f'No *.ckpt files found below {models_root_path.absolute()}')
        # Natural ("human") order on the file name only, ignoring subdirectories.
        found_checkpoints = natsorted(found_checkpoints, key=lambda ckpt: ckpt.name)

        if model is None:
            # NOTE(security): torch.load unpickles arbitrary objects and can run
            # code; only load model_class.obj files from trusted sources.
            model = torch.load(models_root_path / 'model_class.obj')
        assert model is not None
        return cls(weights=found_checkpoints[n], model=model)

    def __init__(self, **kwargs):
        """Store the checkpoint path (``weights``) and the model class (``model``, required)."""
        # Checkpoint location; a Path when created via load_checkpoint, '' if omitted.
        self.weights = kwargs.get('weights', '')
        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        """Load ``self.weights`` into a new model and prepare it for inference.

        Returns:
            The object returned by ``model.load_from_checkpoint``, after its
            ``eval()`` and ``freeze()`` methods have been called.
        """
        pretrained_model = self.model.load_from_checkpoint(self.weights)
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model