ml_lib/utils/model_io.py
Steffen Illium b5e3e5aec1 Dataset rdy
2021-02-16 10:18:03 +01:00

108 lines
3.1 KiB
Python

from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Union

try:
    # The ABC aliases were removed from ``collections`` in Python 3.10.
    from collections.abc import Mapping
except ImportError:  # pragma: no cover - legacy interpreters only
    from collections import Mapping

import torch
from natsort import natsorted
from torch import nn
# Hyperparameter object
class ModelParameters(Namespace, Mapping):
    """Hyperparameter container that is both a ``Namespace`` and a ``Mapping``.

    All values live in the instance ``__dict__``. Item access
    (``params['activation']``) returns the stored value unchanged, whereas
    attribute access (``params.activation``) resolves the stored activation
    name (case-insensitively) to the matching class in ``_activations``.
    """

    # Supported activation classes, keyed by their lower-case name.
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        gelu=nn.GELU,
        elu=nn.ELU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    def __init__(self, parameter_mapping):
        """Store every entry of ``parameter_mapping`` as an attribute.

        Args:
            parameter_mapping: A ``Namespace`` (its ``__dict__`` is used) or a
                mapping with string keys.
        """
        if isinstance(parameter_mapping, Namespace):
            parameter_mapping = parameter_mapping.__dict__
        super(ModelParameters, self).__init__(**parameter_mapping)

    @property
    def as_dict(self):
        """Return the parameters as a dict, with the activation as its lower-case name."""
        # Item access instead of ``self.get``/``self.keys`` so that a stored
        # hyperparameter called "get" or "keys" cannot shadow the Mapping API.
        return {
            key: self.activation_as_string if key == 'activation' else self[key]
            for key in self
        }

    @property
    def activation_as_string(self):
        """Return the stored activation name in lower case.

        Raises:
            KeyError: If no ``activation`` parameter is stored.
        """
        return self['activation'].lower()

    @property
    def module_kwargs(self):
        """Return a deep copy of the parameters, prepared as module keyword arguments.

        ``activation`` is replaced by the resolved activation class and
        ``in_shape`` is dropped when present.

        Raises:
            KeyError: If no ``activation`` parameter is stored or its name is
                not listed in ``_activations``.
        """
        kwargs = deepcopy(self.__dict__)
        kwargs['activation'] = self._activations[self.activation_as_string]
        # ``in_shape`` is not forwarded; tolerate parameter sets without it.
        kwargs.pop('in_shape', None)
        return kwargs

    def __getitem__(self, k):
        """Return the raw stored value for key ``k``."""
        return self.__dict__[k]

    def __len__(self):
        """Return the number of stored parameters."""
        return len(self.__dict__)

    def __iter__(self):
        """Iterate over a snapshot of the parameter names."""
        return iter(list(self.__dict__))

    def __delitem__(self, key):
        """Delete the parameter ``key``; returns ``True`` on success."""
        self.__delattr__(key)
        return True

    def __getattribute__(self, name):
        """Resolve ``activation`` to its class; defer everything else to the default lookup.

        Raises:
            AttributeError: If ``activation`` is requested but not stored, so
                that ``hasattr`` and ``getattr`` with a default behave normally.
            KeyError: If the stored activation name is not in ``_activations``.
        """
        if name != 'activation':
            return super(ModelParameters, self).__getattribute__(name)
        try:
            activation_name = self['activation']
        except KeyError:
            # Attribute lookups must signal absence with AttributeError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'activation'"
            ) from None
        return self._activations[activation_name.lower()]
class SavedLightningModels(object):
    """Pair a checkpoint file with a model class so the model can be restored."""

    @classmethod
    def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
        """Select a checkpoint and build a ``SavedLightningModels`` for it.

        Args:
            models_root_path: Directory of the saved run, as ``str`` or ``Path``.
            model: Model class to restore with. When ``None`` it is loaded
                from ``<models_root_path>/model_class.obj``.
            n: Index into the naturally sorted (by file name) list of
                ``*.ckpt`` files found below ``models_root_path``.
                Only used when ``checkpoint`` is ``None``.
            checkpoint: Explicit path to a checkpoint file.

        Raises:
            AssertionError: If the root path or the explicit checkpoint does
                not exist, or no model could be obtained.
            IndexError: If no checkpoint exists at index ``n``.
        """
        # Coerce first so plain strings work for the existence check as well.
        root = Path(models_root_path)
        assert root.exists(), f'The path {root.absolute()} does not exist!'

        if checkpoint is not None:
            checkpoint_path = Path(checkpoint)
            assert checkpoint_path.exists(), f'The path ({checkpoint_path}) does not exist.'
        else:
            found_checkpoints = natsorted(list(root.rglob('*.ckpt')), key=lambda ckpt: ckpt.name)
            try:
                checkpoint_path = found_checkpoints[n]
            except IndexError:
                raise IndexError(
                    f'No checkpoint at index {n}: found {len(found_checkpoints)} '
                    f'"*.ckpt" file(s) below {root.absolute()}.'
                ) from None

        if model is None:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code - only point this at trusted run directories.
            # Newer torch releases may need ``weights_only=False`` here - TODO confirm.
            model = torch.load(root / 'model_class.obj')
        assert model is not None
        return cls(weights=checkpoint_path, model=model)

    def __init__(self, **kwargs):
        """Store the checkpoint path and the model class.

        Keyword Args:
            weights: Path to the checkpoint file (defaults to ``''``).
            model: Object providing ``load_from_checkpoint``; required.
        """
        self.weights: Path = Path(kwargs.get('weights', ''))
        # hparams.yaml is expected next to the checkpoint file.
        self.hparams: Path = self.weights.parent / 'hparams.yaml'
        self.model = kwargs.get('model', None)
        assert self.model is not None, 'A model must be provided.'

    def restore(self):
        """Load the model from the checkpoint, switch it to eval mode and freeze it.

        Returns:
            The object returned by ``self.model.load_from_checkpoint``.
        """
        # ``self.hparams`` is not passed on; only the checkpoint path is used.
        pretrained_model = self.model.load_from_checkpoint(str(self.weights))
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model