diff --git a/audio_toolset/audio_to_mel_dataset.py b/audio_toolset/audio_to_mel_dataset.py
index 058a326..f86e1aa 100644
--- a/audio_toolset/audio_to_mel_dataset.py
+++ b/audio_toolset/audio_to_mel_dataset.py
@@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
 from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
 
 
-class _AudioToMelDataset(Dataset, ABC):
+import librosa
+
+
+class LibrosaAudioToMelDataset(Dataset):
 
     @property
     def audio_file_duration(self):
-        raise NotImplementedError
+        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
 
     @property
     def sampling_rate(self):
-        raise NotImplementedError
+        return self.mel_kwargs.get('sr', None)
 
     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
-        self.ignored_kwargs = kwargs
+        super(LibrosaAudioToMelDataset, self).__init__()
+
+        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
+        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
         self.mel_kwargs = mel_kwargs
         self.reset = reset
         self.audio_path = Path(audio_file_path)
 
         mel_folder_suffix = self.audio_path.parent.parent.name
+
         self.mel_file_path = Path(str(self.audio_path)
                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')
                                   .replace(self.audio_path.suffix, '.npy'))
@@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC):
                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
                                        mel_kwargs['n_mels'], transform=mel_augmentations)
 
-    def _build_mel(self):
-        raise NotImplementedError
-
-    def __getitem__(self, item):
-        try:
-            return self.dataset[item]
-        except FileNotFoundError:
-            assert self._build_mel()
-            return self.dataset[item]
-
-    def __len__(self):
-        return len(self.dataset)
-
-
-import librosa
-
-
-class LibrosaAudioToMelDataset(_AudioToMelDataset):
-
-    @property
-    def audio_file_duration(self):
-        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
-
-    @property
-    def sampling_rate(self):
-        return self.mel_kwargs.get('sr', None)
-
-    def __init__(self, audio_file_path, *args, **kwargs):
-
-        audio_file_path = Path(audio_file_path)
-        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
-        mel_kwargs = kwargs.get('mel_kwargs', dict())
-        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
-        kwargs.update(mel_kwargs=mel_kwargs)
-
-        super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
-
         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs),
                                        MelToImage()
                                        ])
 
-    def _build_mel(self):
+    def __getitem__(self, item):
+        return self.dataset[item]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def build_mel(self):
         if self.reset:
             self.mel_file_path.unlink(missing_ok=True)
         if not self.mel_file_path.exists():
-            lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
-            lockfile.touch(exist_ok=False)
             raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate)
             mel_sample = self._mel_transform(raw_sample)
             with self.mel_file_path.open('wb') as mel_file:
                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
-            lockfile.unlink(missing_ok=False)
         else:
             pass
 
diff --git a/audio_toolset/mel_dataset.py b/audio_toolset/mel_dataset.py
index 6b6f245..4948736 100644
--- a/audio_toolset/mel_dataset.py
+++ b/audio_toolset/mel_dataset.py
@@ -11,13 +11,16 @@ class TorchMelDataset(Dataset):
     def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
         super(TorchMelDataset, self).__init__()
-        self.sampling_rate = sampling_rate
-        self.audio_file_len = audio_file_len
-        self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None
+        self.sampling_rate = int(sampling_rate)
+        self.audio_file_len = int(audio_file_len)
+        if auto_pad_to_shape and sub_segment_len:
+            self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
+        else:
+            self.padding = None
         self.path = Path(mel_path)
-        self.sub_segment_len = sub_segment_len
-        self.mel_hop_len = mel_hop_len
-        self.sub_segment_hop_len = sub_segment_hop_len
+        self.sub_segment_len = int(sub_segment_len)
+        self.mel_hop_len = int(mel_hop_len)
+        self.sub_segment_hop_len = int(sub_segment_hop_len)
         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
         if self.sub_segment_len and self.sub_segment_hop_len:
             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
@@ -27,8 +30,6 @@ class TorchMelDataset(Dataset):
         self.transform = transform
 
     def __getitem__(self, item):
-        while Path(str(self.path).replace(self.path.suffix, '.lock')).exists():
-            time.sleep(0.01)
         with self.path.open('rb') as mel_file:
             mel_spec = pickle.load(mel_file, fix_imports=True)
         start = self.offsets[item]
@@ -38,7 +39,7 @@ class TorchMelDataset(Dataset):
             snippet = self.transform(snippet)
         if self.padding:
             snippet = self.padding(snippet)
-        return snippet, self.label
+        return self.path.__str__(), snippet, self.label
 
     def __len__(self):
         return len(self.offsets)
diff --git a/experiments.py b/experiments.py
deleted file mode 100644
index 66d4171..0000000
--- a/experiments.py
+++ /dev/null
@@ -1,67 +0,0 @@
-
-import torchaudio
-if sys.platform =='windows':
-    torchaudio.set_audio_backend('soundfile')
-else:
-    torchaudio.set_audio_backend('sox_io')
-
-
-class PyTorchAudioToMelDataset(_AudioToMelDataset):
-
-    @property
-    def audio_file_duration(self):
-        info_obj = torchaudio.info(self.audio_path)
-        return info_obj.num_frames / info_obj.sample_rate
-
-    @property
-    def sampling_rate(self):
-        return self.mel_kwargs['sample_rate']
-
-    def __init__(self, audio_file_path, *args, **kwargs):
-        super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
-
-        audio_file_path = Path(audio_file_path)
-        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
-
-        from torchaudio.transforms import MelSpectrogram
-        self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs),
-                                       MelToImage()
-                                       ])
-
-    def _build_mel(self):
-        if self.reset:
-            self.mel_file_path.unlink(missing_ok=True)
-        if not self.mel_file_path.exists():
-            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
-            lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
-            lock_file.touch(exist_ok=False)
-
-            try:
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            except RuntimeError:
-                import soundfile
-
-                data, samplerate = soundfile.read(self.audio_path)
-                # sf.available_formats()
-                # sf.available_subtypes()
-                soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32')
-
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            if sample_rate != self.sampling_rate:
-                resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate))
-                audio_sample = resample(audio_sample)
-            if audio_sample.shape[0] > 1:
-                # Transform Stereo to Mono
-                audio_sample = audio_sample.mean(dim=0, keepdim=True)
-            mel_sample = self._mel_transform(audio_sample)
-            with self.mel_file_path.open('wb') as mel_file:
-                pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
-            lock_file.unlink()
-        else:
-            # print(f"Already existed.. Skipping {filename}")
-            # mel_file = mel_file
-            pass
-
-        # with mel_file.open(mode='rb') as f:
-        #     mel_sample = pickle.load(f, fix_imports=True)
-        return self.mel_file_path.exists()
diff --git a/metrics/attention_rollout.py b/metrics/attention_rollout.py
new file mode 100644
index 0000000..972f151
--- /dev/null
+++ b/metrics/attention_rollout.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+from einops import reduce
+
+
+import torch
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
+
+from ml_lib.metrics._base_score import _BaseScores
+
+
+class AttentionRollout(_BaseScores):
+
+    def __init__(self, *args):
+        super(AttentionRollout, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+        #######################################################################################
+        # Additional Score  -  Histogram Distances - Image Plotting
+        #######################################################################################
+        #
+        # INIT
+        attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
+        attn_reduce_heads = [reduce(x, '') for x in attn_weights]
+
+        if self.model.params.use_residual:
+            residual_att = np.eye(att_mat.shape[1])[None, ...]
+            aug_att_mat = att_mat + residual_att
+            aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
+        else:
+            aug_att_mat = att_mat
+
+        joint_attentions = np.zeros(aug_att_mat.shape)
+
+        layers = joint_attentions.shape[0]
+        joint_attentions[0] = aug_att_mat[0]
+        for i in np.arange(1, layers):
+            joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
+
+
+
+
+
+
diff --git a/metrics/multi_class_classification.py b/metrics/multi_class_classification.py
index fe7d6d8..4bb77f6 100644
--- a/metrics/multi_class_classification.py
+++ b/metrics/multi_class_classification.py
@@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores):
         #######################################################################################
         #
         # Confusion matrix
-
+        fig1, ax1 = plt.subplots(dpi=96)
         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                               labels=[class_names[key] for key in class_names.keys()],
                               normalize='all')
         disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                       display_labels=[class_names[i] for i in range(self.model.n_classes)]
                                       )
-        disp.plot(include_values=True)
+        disp.plot(include_values=True, ax=ax1)
 
-        self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
+        self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
 
         plt.close('all')
-        return summary_dict
\ No newline at end of file
+        return summary_dict
diff --git a/modules/blocks.py b/modules/blocks.py
index 2d3a359..d4c425b 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
 
         for attn, mlp in zip(self.attns, self.mlps):
             # Attention
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
+            attn_tensor = self.norm(tensor)
             if return_attn_weights:
-                tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
+                attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
                 attn_weights.append(attn_weight)
             else:
-                tensor = attn(tensor, mask=mask)
-            tensor = tensor + skip_connection
+                attn_tensor = attn(attn_tensor, mask=mask)
+            tensor = attn_tensor + tensor
 
             # MLP
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
-            tensor = mlp(tensor)
-            tensor = tensor + skip_connection
+            mlp_tensor = self.norm(tensor)
+            mlp_tensor = mlp(mlp_tensor)
+            tensor = tensor + mlp_tensor
 
         return (tensor, attn_weights) if return_attn_weights else tensor
diff --git a/modules/util.py b/modules/util.py
index be56c35..6492bc2 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,3 +1,6 @@
+import inspect
+from argparse import ArgumentParser
+
 from functools import reduce
 
 from abc import ABC
@@ -5,13 +8,14 @@ from pathlib import Path
 
 import torch
 from operator import mul
+from pytorch_lightning.utilities import argparse_utils
 from torch import nn
 from torch.nn import functional as F, Unfold
 
 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters
-from ..utils.tools import locate_and_import_class
+from ..utils.tools import locate_and_import_class, add_argparse_args
 
 try:
     import pytorch_lightning as pl
@@ -32,14 +36,18 @@ try:
                 print(e)
                 return -1
 
-        def __init__(self, hparams):
-            super(LightningBaseModule, self).__init__()
+        @classmethod
+        def from_argparse_args(cls, args, **kwargs):
+            return argparse_utils.from_argparse_args(cls, args, **kwargs)
 
-            # Set Parameters
-            ################################
-            self.hparams = hparams
-            self.params = ModelParameters(hparams)
-            self.lr = self.params.lr or 1e-4
+        @classmethod
+        def add_argparse_args(cls, parent_parser):
+            return add_argparse_args(cls, parent_parser)
+
+        def __init__(self, model_parameters, weight_init='xavier_normal_'):
+            super(LightningBaseModule, self).__init__()
+            self._weight_init = weight_init
+            self.params = ModelParameters(model_parameters)
 
         def size(self):
             return self.shape
@@ -47,15 +55,6 @@ try:
         def additional_scores(self, outputs):
             raise NotImplementedError
 
-        @property
-        def dataset_class(self):
-            try:
-                return locate_and_import_class(self.params.class_name, folder_path='datasets')
-            except AttributeError as e:
-                raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' +
-                                     f'was not found!\n' +
-                                     f'{e}')
-
         def save_to_disk(self, model_path):
             Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
             if not (model_path / 'model_class.obj').exists():
@@ -86,8 +85,12 @@ try:
         def test_epoch_end(self, outputs):
             raise NotImplementedError
 
-        def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-            weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+        def init_weights(self):
+            if isinstance(self._weight_init, str):
+                mod = __import__('torch.nn.init', fromlist=[self._weight_init])
+                self._weight_init = getattr(mod, self._weight_init)
+            assert callable(self._weight_init)
+            weight_initializer = WeightInit(in_place_init_function=self._weight_init)
             self.apply(weight_initializer)
 
     module_types = (LightningBaseModule, nn.Module,)
diff --git a/utils/_basedatamodule.py b/utils/_basedatamodule.py
new file mode 100644
index 0000000..c2e0a2f
--- /dev/null
+++ b/utils/_basedatamodule.py
@@ -0,0 +1,29 @@
+from pytorch_lightning import LightningDataModule
+
+
+# Dataset Options
+from ml_lib.utils.tools import add_argparse_args
+
+DATA_OPTION_test = 'test'
+DATA_OPTION_devel = 'devel'
+DATA_OPTION_train = 'train'
+DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
+
+
+class _BaseDataModule(LightningDataModule):
+
+    @property
+    def shape(self):
+        return self.datasets[DATA_OPTION_train].sample_shape
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser):
+        return add_argparse_args(cls, parent_parser)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.datasets = dict()
+
+    def transfer_batch_to_device(self, batch, device):
+        return batch.to(device)
+
diff --git a/utils/logging.py b/utils/logging.py
index d0d983f..f9ef373 100644
--- a/utils/logging.py
+++ b/utils/logging.py
@@ -1,19 +1,34 @@
-from abc import ABC
+import inspect
+from argparse import ArgumentParser
 from pathlib import Path
 
+import os
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.neptune import NeptuneLogger
 from neptune.api_exceptions import ProjectNotFound
-# noinspection PyUnresolvedReferences
+
 from pytorch_lightning.loggers.csv_logs import CSVLogger
+from pytorch_lightning.utilities import argparse_utils
 
-from .config import Config
+from ml_lib.utils.tools import add_argparse_args
 
 
-class Logger(LightningLoggerBase, ABC):
+class Logger(LightningLoggerBase):
+
+    @classmethod
+    def from_argparse_args(cls, args, **kwargs):
+        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+
+    @property
+    def name(self) -> str:
+        return self._name
 
     media_dir = 'media'
 
+    @classmethod
+    def add_argparse_args(cls, parent_parser):
+        return add_argparse_args(cls, parent_parser)
+
     @property
     def experiment(self):
         if self.debug:
@@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC):
     def log_dir(self):
         return Path(self.csvlogger.experiment.log_dir)
 
-    @property
-    def name(self):
-        return self.config.name
-
     @property
     def project_name(self):
-        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
+        return f"{self.owner}/{self.name.replace('_', '-')}"
 
     @property
     def version(self):
-        return self.config.get('main', 'seed')
+        return self.seed
+
+    @property
+    def save_dir(self):
+        return self.log_dir
 
     @property
     def outpath(self):
-        return Path(self.config.train.outpath) / self.config.model.type
+        return Path(self.root_out) / self.model_name
 
-    @property
-    def exp_path(self):
-        return Path(self.outpath) / self.name
-
-    def __init__(self, config: Config):
+    def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
@@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC):
         """
         super(Logger, self).__init__()
 
-        self.config = config
-        self.debug = self.config.main.debug
-        if self.debug:
-            self.config.add_section('project')
-            self.config.set('project', 'owner', 'testuser')
-            self.config.set('project', 'name', 'test')
-            self.config.set('project', 'neptune_key', 'XXX')
+        self.debug = debug
+        self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
+        self.owner = owner if not self.debug else 'testuser'
+        self.neptune_key = neptune_key if not self.debug else 'XXX'
+        self.root_out = outpath if not self.debug else 'debug_out'
+        self.seed = seed
+        self.model_name = model_name
+
         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
-                                    api_key=self.config.project.neptune_key,
+                                    api_key=self.neptune_key,
                                     experiment_name=self.name,
-                                    project_name=self.project_name,
-                                    params=self.config.model_paramters)
+                                    project_name=self.project_name)
         try:
             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
         except ProjectNotFound as e:
@@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC):
             print(e)
 
         self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
-        self.log_config_as_ini()
 
     def log_hyperparams(self, params):
         self.neptunelogger.log_hyperparams(params)
@@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC):
         self.csvlogger.close()
         self.neptunelogger.close()
 
-    def log_config_as_ini(self):
-        self.config.write(self.log_dir / 'config.ini')
-
-    def log_text(self, name, text, step_nb=0, **_):
+    def log_text(self, name, text, **_):
         # TODO Implement Offline variant.
-        self.neptunelogger.log_text(name, text, step_nb)
+        self.neptunelogger.log_text(name, text)
 
     def log_metric(self, metric_name, metric_value, **kwargs):
         self.csvlogger.log_metrics(dict(metric_name=metric_value))
         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
 
     def log_image(self, name, image, ext='png', **kwargs):
-
         step = kwargs.get('step', None)
         image_name = f'{step}_{name}' if step is not None else name
         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
diff --git a/utils/model_io.py b/utils/model_io.py
index 9bcc18d..f3ae503 100644
--- a/utils/model_io.py
+++ b/utils/model_io.py
@@ -13,6 +13,10 @@ from torch import nn
 # Hyperparamter Object
 class ModelParameters(Namespace, Mapping):
 
+    @property
+    def as_dict(self):
+        return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
+
     @property
     def activation_as_string(self):
         return self['activation'].lower()
@@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
         if name == 'activation':
             return self._activations[self['activation'].lower()]
         else:
-            try:
-                return super(ModelParameters, self).__getattribute__(name)
-            except AttributeError as e:
-                if name == 'stretch':
-                    return False
-                else:
-                    return None
+            return super(ModelParameters, self).__getattribute__(name)
 
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
@@ -88,16 +86,20 @@ class SavedLightningModels(object):
             model = torch.load(models_root_path / 'model_class.obj')
         assert model is not None
 
-        return cls(weights=str(checkpoint_path), model=model)
+        return cls(weights=checkpoint_path, model=model)
 
     def __init__(self, **kwargs):
-        self.weights: str = kwargs.get('weights', '')
+        self.weights: Path = Path(kwargs.get('weights', ''))
+        self.hparams: Path = self.weights.parent / 'hparams.yaml'
 
         self.model = kwargs.get('model', None)
         assert self.model is not None
 
     def restore(self):
-        pretrained_model = self.model.load_from_checkpoint(self.weights)
+
+        pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
+        # , hparams_file=self.hparams.__str__())
         pretrained_model.eval()
         pretrained_model.freeze()
-        return pretrained_model
\ No newline at end of file
+        return pretrained_model
+
diff --git a/utils/tools.py b/utils/tools.py
index a85b495..08e2c20 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -1,6 +1,9 @@
 import importlib
+import inspect
 import pickle
 import shelve
+from argparse import ArgumentParser
+from ast import literal_eval
 from pathlib import Path, PurePath
 from typing import Union
 
@@ -9,6 +12,13 @@ import torch
 import random
 
 
+def auto_cast(a):
+  try:
+    return literal_eval(a)
+  except:
+    return a
+
+
 def to_one_hot(idx_array, max_classes):
     one_hot = np.zeros((idx_array.size, max_classes))
     one_hot[np.arange(idx_array.size), idx_array] = 1
@@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
            continue
     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
 
+
+def add_argparse_args(cls, parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    full_arg_spec = inspect.getfullargspec(cls.__init__)
+    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
+    for idx, argument in enumerate(full_arg_spec.args):
+        if argument == 'self':
+            continue
+        if idx < n_non_defaults:
+            parser.add_argument(f'--{argument}', type=int)
+        else:
+            argument_type = type(argument)
+            parser.add_argument(f'--{argument}',
+                                type=argument_type,
+                                default=full_arg_spec.defaults[idx - n_non_defaults]
+                                )
+    return parser