from argparse import Namespace
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path
from typing import Union

import torch
from natsort import natsorted
from torch import nn


# Hyperparameter Object
class ModelParameters(Namespace, Mapping):

    @property
    def as_dict(self):
        # Serialize every parameter; the activation is exported as its string name.
        return {key: self.get(key) if key != 'activation' else self.activation_as_string
                for key in self.keys()}

    @property
    def activation_as_string(self):
        return self['activation'].lower()

    @property
    def module_kwargs(self):
        parameter_mapping = deepcopy(self.__dict__)
        # Swap the activation string for the actual nn activation class
        # (resolved through the __getattribute__ hook below).
        parameter_mapping.update(
            dict(
                activation=self.activation
            )
        )
        # Get rid of parameters that are not keyword arguments of the module.
        del parameter_mapping['in_shape']

        return parameter_mapping

    # Mapping interface, backed by the Namespace __dict__.
    def __getitem__(self, k):
        return self.__dict__[k]

    def __len__(self) -> int:
        return len(self.__dict__)

    def __iter__(self):
        return iter(list(self.__dict__.keys()))

    def __delitem__(self, key):
        self.__delattr__(key)
        return True

    def __getattribute__(self, name):
        # Attribute-style access to 'activation' yields the nn activation class;
        # item-style access (self['activation']) still yields the raw string.
        if name == 'activation':
            return self._activations[self['activation'].lower()]
        else:
            return super(ModelParameters, self).__getattribute__(name)

    # Supported activation names and the nn classes they map to.
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        gelu=nn.GELU,
        elu=nn.ELU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    def __init__(self, parameter_mapping):
        # Accept either a plain dict or an argparse.Namespace.
        if isinstance(parameter_mapping, Namespace):
            parameter_mapping = parameter_mapping.__dict__
        super(ModelParameters, self).__init__(**parameter_mapping)

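
# Usage sketch (illustrative; the parameter names 'activation', 'in_shape' and
# 'hidden_size' are assumptions for this example, not fields ModelParameters
# requires). Note that 'activation' resolves differently per access style:
#
#     params = ModelParameters(Namespace(activation='relu', in_shape=(28, 28), hidden_size=128))
#     params['activation']   # 'relu'    -- the raw string, via __getitem__
#     params.activation      # nn.ReLU   -- the class, via the __getattribute__ hook
#     params.module_kwargs   # {'activation': nn.ReLU, 'hidden_size': 128}
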
class SavedLightningModels(object):

    @classmethod
    def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
        models_root_path = Path(models_root_path)
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
        if checkpoint is not None:
            checkpoint_path = Path(checkpoint)
            assert checkpoint_path.exists(), f'The path {checkpoint_path.absolute()} does not exist!'
        else:
            # Fall back to the n-th checkpoint below the root path, in natural sort order.
            found_checkpoints = list(models_root_path.rglob('*.ckpt'))
            checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
        if model is None:
            # The model class is expected to have been pickled alongside the checkpoints.
            model = torch.load(models_root_path / 'model_class.obj')
        assert model is not None

        return cls(weights=checkpoint_path, model=model)

    def __init__(self, **kwargs):
        self.weights: Path = Path(kwargs.get('weights', ''))
        # hparams.yaml is expected to sit next to the checkpoint file.
        self.hparams: Path = self.weights.parent / 'hparams.yaml'

        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        # Load the weights and switch the model into inference mode.
        pretrained_model = self.model.load_from_checkpoint(str(self.weights))
        # Optionally: hparams_file=str(self.hparams)
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model
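

if __name__ == '__main__':
    # Minimal sanity demo (illustrative only): the hyperparameter names below are
    # assumptions chosen for the example, not fields required by ModelParameters.
    demo_params = ModelParameters(Namespace(activation='tanh', in_shape=(1, 28, 28), hidden_size=64))
    print(demo_params.as_dict)        # {'activation': 'tanh', 'in_shape': (1, 28, 28), 'hidden_size': 64}
    print(demo_params.module_kwargs)  # {'activation': <class 'torch.nn.modules.activation.Tanh'>, 'hidden_size': 64}

    # Restoring a saved PyTorch-Lightning model would look like this; it needs a
    # real experiment directory with '*.ckpt' files and a pickled 'model_class.obj'
    # ('output/my_experiment' is a placeholder path):
    #     loader = SavedLightningModels.load_checkpoint(Path('output/my_experiment'))
    #     model = loader.restore()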