from argparse import Namespace
from pathlib import Path
from natsort import natsorted
from torch import nn


# Hyperparameter object
class ModelParameters(Namespace):
    """Flat hyperparameter namespace merged from model, training and data params.

    Every attribute of ``model_param``, ``train_param`` and ``data_param`` is
    exposed directly on this object. On name collisions, later groups win
    (data over train over model). The original groups stay reachable through
    the read-only ``model_param``, ``train_param`` and ``data_param`` properties.

    Reading ``activation`` is special-cased: the stored activation *name*
    (e.g. ``'tanh'``) is translated into the matching ``torch.nn`` module
    class. Unknown or missing names fall back to ``nn.ReLU``.
    """

    # Activation name -> torch.nn module class returned by ``self.activation``.
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    @property
    def model_param(self):
        """The model-parameter namespace given to ``__init__``."""
        return self._model_param

    @property
    def train_param(self):
        """The training-parameter namespace given to ``__init__``."""
        return self._train_param

    @property
    def data_param(self):
        """The data-parameter namespace given to ``__init__``."""
        return self._data_param

    def __init__(self, model_param, train_param, data_param):
        """Merge the three parameter namespaces into this one.

        Args:
            model_param: Namespace-like object (anything ``vars()`` accepts)
                holding model hyperparameters.
            train_param: Same, holding training hyperparameters.
            data_param: Same, holding data hyperparameters.
        """
        self._model_param = model_param
        self._train_param = train_param
        self._data_param = data_param
        # Copy before merging. ``vars()`` returns the argument's live
        # ``__dict__``, so updating it in place would leak the training and
        # data parameters into the caller's ``model_param`` object.
        kwargs = dict(vars(model_param))
        kwargs.update(vars(train_param))
        kwargs.update(vars(data_param))
        super().__init__(**kwargs)

    def __getattribute__(self, item):
        """Resolve ``activation`` to an ``nn`` module class; delegate all else."""
        if item == 'activation':
            # Use the stored activation *name* as the lookup key. The literal
            # string 'activation' is never a key of ``_activations``.
            try:
                name = super().__getattribute__(item)
            except AttributeError:
                return nn.ReLU
            try:
                return self._activations[name]
            except (KeyError, TypeError):
                # Unknown (or unhashable) activation value: default to ReLU.
                return nn.ReLU
        return super().__getattribute__(item)


class SavedLightningModels(object):
    """Locate a saved Lightning checkpoint on disk and restore it for inference.

    An instance holds the checkpoint path (``weights``), the tags file path
    (``tags``) and the ``model`` whose ``load_from_metrics`` rebuilds the
    network in ``restore``.
    """

    @classmethod
    def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
        """Build an instance from the ``n``-th checkpoint under ``models_root_path``.

        Args:
            models_root_path: Directory (``str`` or ``Path``), searched
                recursively for ``*.ckpt`` files.
            model: Object exposing ``load_from_metrics``; stored for ``restore``.
            n: Index into the found checkpoints after natural sorting by file
                name. The default ``-1`` selects the last one.
            tags_file_path: Tags CSV to use. When empty, the first
                ``meta_tags.csv`` found recursively under ``models_root_path``
                is used.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: If ``models_root_path`` does not exist.
            IndexError: If there is no checkpoint at index ``n``, or no
                ``meta_tags.csv`` is found when ``tags_file_path`` is empty.
        """
        # Convert before touching Path methods so plain string paths work too.
        models_root_path = Path(models_root_path)
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'

        found_checkpoints = natsorted(models_root_path.rglob('*.ckpt'), key=lambda ckpt: ckpt.name)

        if not tags_file_path:
            # NOTE(review): rglob order is filesystem-dependent. If several
            # meta_tags.csv files exist, the one picked is not guaranteed to
            # belong to the selected checkpoint.
            tags_file_path = list(models_root_path.rglob('meta_tags.csv'))[0]

        return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)

    def __init__(self, **kwargs):
        """Store the checkpoint location, tags file and model.

        Keyword Args:
            weights: Checkpoint file path, ``str`` or ``Path`` (default ``''``).
            tags: Tags CSV path, ``str`` or ``Path`` (default ``''``).
            model: Object exposing ``load_from_metrics``. Required; presumably
                the LightningModule subclass to restore.

        Raises:
            AssertionError: If ``model`` is missing or ``None``.
        """
        # ``load_checkpoint`` passes ``Path`` objects here; plain strings work too.
        self.weights = kwargs.get('weights', '')
        self.tags = kwargs.get('tags', '')

        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        """Load the stored checkpoint and return it prepared for inference.

        Calls ``self.model.load_from_metrics`` with the stored weights and tags
        paths. It then calls ``eval()`` and ``freeze()`` on the result before
        returning it.
        """
        # NOTE(review): ``load_from_metrics`` is an old PyTorch Lightning API;
        # later releases use ``load_from_checkpoint``. Confirm the pinned version.
        pretrained_model = self.model.load_from_metrics(
            weights_path=self.weights,
            tags_csv=self.tags
        )
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model