from argparse import Namespace
from pathlib import Path

from natsort import natsorted
from torch import nn


# Hyperparameter object
class ModelParameters(Namespace):
    """Flat namespace that merges model, training and data hyperparameters.

    The three source namespaces are kept as ``model_param``, ``train_param``
    and ``data_param``. In addition, every attribute of all three is exposed
    directly on this object. On name clashes, ``data_param`` overrides
    ``train_param``, which overrides ``model_param``, because later
    ``update`` calls overwrite earlier ones.

    Reading ``activation`` does not return the stored value verbatim. A
    stored name is mapped through ``_activations`` to an ``nn`` activation
    class; a class stored directly is returned unchanged. ``nn.ReLU`` is
    returned when no activation is set or the name is not recognised.
    """

    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh,
    )

    def __init__(self, model_param, train_param, data_param):
        """Merge the attributes of the three namespaces into this one.

        Args:
            model_param: namespace of model hyperparameters.
            train_param: namespace of training hyperparameters.
            data_param: namespace of data hyperparameters.
        """
        self.model_param = model_param
        self.train_param = train_param
        self.data_param = data_param
        # Copy before merging: ``vars()`` returns the live ``__dict__`` of
        # ``model_param``. Updating it in place would leak the training and
        # data settings into the caller's namespace.
        kwargs = dict(vars(model_param))
        kwargs.update(vars(train_param))
        kwargs.update(vars(data_param))
        super(ModelParameters, self).__init__(**kwargs)

    def __getattribute__(self, item):
        """Resolve ``activation`` to an ``nn`` class; defer everything else."""
        if item == 'activation':
            try:
                configured = super(ModelParameters, self).__getattribute__(item)
            except AttributeError:
                # No activation configured: keep the historical default.
                return nn.ReLU
            if isinstance(configured, type):
                # Already an activation class (e.g. ``nn.ELU``): pass it through.
                return configured
            try:
                return self._activations[configured]
            except (KeyError, TypeError):
                # Unknown (or unhashable) activation name: fall back to ReLU.
                return nn.ReLU
        return super(ModelParameters, self).__getattribute__(item)


class SavedLightningModels(object):
    """Locate a saved Lightning checkpoint and restore a frozen model from it."""

    @classmethod
    def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
        """Build an instance for the ``n``-th checkpoint under ``models_root_path``.

        Args:
            models_root_path: directory (``str`` or ``Path``) searched
                recursively for ``*.ckpt`` files.
            model: object exposing ``load_from_metrics``; it is used later
                by :meth:`restore`.
            n: index into the found checkpoints, which are naturally sorted
                by file name. The default ``-1`` selects the last one.
            tags_file_path: path to the tags CSV. When empty, the first
                ``meta_tags.csv`` found recursively under the root is used.

        Returns:
            A ``SavedLightningModels`` holding the selected weights and tags.

        Raises:
            FileNotFoundError: if the root does not exist, contains no
                ``*.ckpt`` files, or (when needed) contains no
                ``meta_tags.csv``.
            IndexError: if ``n`` is out of range for the found checkpoints.
        """
        # Normalise first so that plain string paths work as well.
        models_root_path = Path(models_root_path)
        if not models_root_path.exists():
            raise FileNotFoundError(
                f'The path {models_root_path.absolute()} does not exist!'
            )

        found_checkpoints = natsorted(
            models_root_path.rglob('*.ckpt'), key=lambda ckpt: ckpt.name
        )
        if not found_checkpoints:
            raise FileNotFoundError(
                f'No *.ckpt files found under {models_root_path.absolute()}'
            )

        if not tags_file_path:
            # NOTE(review): rglob order is filesystem-dependent. With several
            # runs under the root, this may not be the tags file that belongs
            # to the selected checkpoint.
            tags_file_path = next(models_root_path.rglob('meta_tags.csv'), None)
            if tags_file_path is None:
                raise FileNotFoundError(
                    f'No meta_tags.csv found under {models_root_path.absolute()}'
                )

        return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)

    def __init__(self, **kwargs):
        """Store ``weights``, ``tags`` and ``model`` from keyword arguments.

        ``weights`` and ``tags`` default to ``''``. ``model`` is required.
        """
        # ``weights`` and ``tags`` may be str or Path: load_checkpoint passes Paths.
        self.weights = kwargs.get('weights', '')
        self.tags = kwargs.get('tags', '')
        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        """Load the stored weights/tags into the model, then eval and freeze it.

        Calls ``self.model.load_from_metrics`` with the stored paths, then
        calls ``eval()`` and ``freeze()`` on the returned model.
        """
        # NOTE(review): ``load_from_metrics`` is a legacy PyTorch Lightning
        # API; confirm it exists in the pinned Lightning version.
        pretrained_model = self.model.load_from_metrics(
            weights_path=self.weights,
            tags_csv=self.tags,
        )
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model