Bringing branches up to date
		| @@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage | |||||||
| from ml_lib.audio_toolset.mel_dataset import TorchMelDataset | from ml_lib.audio_toolset.mel_dataset import TorchMelDataset | ||||||
|  |  | ||||||
|  |  | ||||||
| class _AudioToMelDataset(Dataset, ABC): | import librosa | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LibrosaAudioToMelDataset(Dataset): | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def audio_file_duration(self): |     def audio_file_duration(self): | ||||||
|         raise NotImplementedError |         return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def sampling_rate(self): |     def sampling_rate(self): | ||||||
|         raise NotImplementedError |         return self.mel_kwargs.get('sr', None) | ||||||
|  |  | ||||||
|     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False, |     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False, | ||||||
|                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs): |                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs): | ||||||
|         self.ignored_kwargs = kwargs |         super(LibrosaAudioToMelDataset, self).__init__() | ||||||
|  |  | ||||||
|  |         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) | ||||||
|  |         mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path)) | ||||||
|         self.mel_kwargs = mel_kwargs |         self.mel_kwargs = mel_kwargs | ||||||
|         self.reset = reset |         self.reset = reset | ||||||
|         self.audio_path = Path(audio_file_path) |         self.audio_path = Path(audio_file_path) | ||||||
|  |  | ||||||
|         mel_folder_suffix = self.audio_path.parent.parent.name |         mel_folder_suffix = self.audio_path.parent.parent.name | ||||||
|  |  | ||||||
|         self.mel_file_path = Path(str(self.audio_path) |         self.mel_file_path = Path(str(self.audio_path) | ||||||
|                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder') |                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder') | ||||||
|                                   .replace(self.audio_path.suffix, '.npy')) |                                   .replace(self.audio_path.suffix, '.npy')) | ||||||
| @@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC): | |||||||
|                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'], |                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'], | ||||||
|                                        mel_kwargs['n_mels'], transform=mel_augmentations) |                                        mel_kwargs['n_mels'], transform=mel_augmentations) | ||||||
|  |  | ||||||
|     def _build_mel(self): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def __getitem__(self, item): |  | ||||||
|         try: |  | ||||||
|             return self.dataset[item] |  | ||||||
|         except FileNotFoundError: |  | ||||||
|             assert self._build_mel() |  | ||||||
|             return self.dataset[item] |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self.dataset) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| import librosa |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class LibrosaAudioToMelDataset(_AudioToMelDataset): |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def audio_file_duration(self): |  | ||||||
|         return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path) |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sampling_rate(self): |  | ||||||
|         return self.mel_kwargs.get('sr', None) |  | ||||||
|  |  | ||||||
|     def __init__(self, audio_file_path, *args, **kwargs): |  | ||||||
|  |  | ||||||
|         audio_file_path = Path(audio_file_path) |  | ||||||
|         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) |  | ||||||
|         mel_kwargs = kwargs.get('mel_kwargs', dict()) |  | ||||||
|         mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path)) |  | ||||||
|         kwargs.update(mel_kwargs=mel_kwargs) |  | ||||||
|  |  | ||||||
|         super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) |  | ||||||
|  |  | ||||||
|         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), |         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), | ||||||
|                                        MelToImage() |                                        MelToImage() | ||||||
|                                        ]) |                                        ]) | ||||||
|  |  | ||||||
|     def _build_mel(self): |     def __getitem__(self, item): | ||||||
|  |         return self.dataset[item] | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.dataset) | ||||||
|  |  | ||||||
|  |     def build_mel(self): | ||||||
|         if self.reset: |         if self.reset: | ||||||
|             self.mel_file_path.unlink(missing_ok=True) |             self.mel_file_path.unlink(missing_ok=True) | ||||||
|         if not self.mel_file_path.exists(): |         if not self.mel_file_path.exists(): | ||||||
|             lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) |  | ||||||
|             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) |             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) | ||||||
|             lockfile.touch(exist_ok=False) |  | ||||||
|             raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate) |             raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate) | ||||||
|             mel_sample = self._mel_transform(raw_sample) |             mel_sample = self._mel_transform(raw_sample) | ||||||
|             with self.mel_file_path.open('wb') as mel_file: |             with self.mel_file_path.open('wb') as mel_file: | ||||||
|                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) |                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) | ||||||
|             lockfile.unlink(missing_ok=False) |  | ||||||
|         else: |         else: | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|   | |||||||
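Note on the merged class: with the torchaudio backend removed, mel building is now an explicit step. The new `__getitem__` no longer catches `FileNotFoundError` to trigger a build, so callers must call `build_mel()` before indexing. A minimal usage sketch; the path, label, and mel parameters are illustrative:

    dataset = LibrosaAudioToMelDataset(
        'data/train/class_a/clip.wav', label=0,
        sample_segment_len=50, sample_hop_len=25,
        mel_kwargs=dict(n_mels=64, hop_length=256),  # 'sr' is filled in from the file
    )
    dataset.build_mel()                       # writes clip.npy into *_mel_folder if missing
    file_path, snippet, label = dataset[0]    # TorchMelDataset now yields a 3-tuple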
| @@ -11,13 +11,16 @@ class TorchMelDataset(Dataset): | |||||||
|     def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len, |     def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len, | ||||||
|                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True): |                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True): | ||||||
|         super(TorchMelDataset, self).__init__() |         super(TorchMelDataset, self).__init__() | ||||||
|         self.sampling_rate = sampling_rate |         self.sampling_rate = int(sampling_rate) | ||||||
|         self.audio_file_len = audio_file_len |         self.audio_file_len = int(audio_file_len) | ||||||
|         self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None |         if auto_pad_to_shape and sub_segment_len: | ||||||
|  |             self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len))) | ||||||
|  |         else: | ||||||
|  |             self.padding = None | ||||||
|         self.path = Path(mel_path) |         self.path = Path(mel_path) | ||||||
|         self.sub_segment_len = sub_segment_len |         self.sub_segment_len = int(sub_segment_len) | ||||||
|         self.mel_hop_len = mel_hop_len |         self.mel_hop_len = int(mel_hop_len) | ||||||
|         self.sub_segment_hop_len = sub_segment_hop_len |         self.sub_segment_hop_len = int(sub_segment_hop_len) | ||||||
|         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1) |         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1) | ||||||
|         if self.sub_segment_len and self.sub_segment_hop_len: |         if self.sub_segment_len and self.sub_segment_hop_len: | ||||||
|             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) |             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) | ||||||
| @@ -27,8 +30,6 @@ class TorchMelDataset(Dataset): | |||||||
|         self.transform = transform |         self.transform = transform | ||||||
|  |  | ||||||
|     def __getitem__(self, item): |     def __getitem__(self, item): | ||||||
|         while Path(str(self.path).replace(self.path.suffix, '.lock')).exists(): |  | ||||||
|             time.sleep(0.01) |  | ||||||
|         with self.path.open('rb') as mel_file: |         with self.path.open('rb') as mel_file: | ||||||
|             mel_spec = pickle.load(mel_file, fix_imports=True) |             mel_spec = pickle.load(mel_file, fix_imports=True) | ||||||
|         start = self.offsets[item] |         start = self.offsets[item] | ||||||
| @@ -38,7 +39,7 @@ class TorchMelDataset(Dataset): | |||||||
|             snippet = self.transform(snippet) |             snippet = self.transform(snippet) | ||||||
|         if self.padding: |         if self.padding: | ||||||
|             snippet = self.padding(snippet) |             snippet = self.padding(snippet) | ||||||
|         return snippet, self.label |         return self.path.__str__(), snippet, self.label | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self.offsets) |         return len(self.offsets) | ||||||
|   | |||||||
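Since `__getitem__` now returns the mel file path alongside the snippet and label (and the lock-file wait was dropped together with the lock files themselves), anything iterating a `DataLoader` over this dataset must unpack three values; the default collate leaves the paths as a list of strings. A sketch under those assumptions:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    for file_paths, snippets, labels in loader:
        ...  # file_paths: list of str; snippets, labels: batched tensors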
| @@ -1,67 +0,0 @@ | |||||||
|  |  | ||||||
| import torchaudio |  | ||||||
| if sys.platform =='windows': |  | ||||||
|     torchaudio.set_audio_backend('soundfile') |  | ||||||
| else: |  | ||||||
|     torchaudio.set_audio_backend('sox_io') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PyTorchAudioToMelDataset(_AudioToMelDataset): |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def audio_file_duration(self): |  | ||||||
|         info_obj = torchaudio.info(self.audio_path) |  | ||||||
|         return info_obj.num_frames / info_obj.sample_rate |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sampling_rate(self): |  | ||||||
|         return self.mel_kwargs['sample_rate'] |  | ||||||
|  |  | ||||||
|     def __init__(self, audio_file_path, *args, **kwargs): |  | ||||||
|         super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) |  | ||||||
|  |  | ||||||
|         audio_file_path = Path(audio_file_path) |  | ||||||
|         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) |  | ||||||
|  |  | ||||||
|         from torchaudio.transforms import MelSpectrogram |  | ||||||
|         self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs), |  | ||||||
|                                        MelToImage() |  | ||||||
|                                        ]) |  | ||||||
|  |  | ||||||
|     def _build_mel(self): |  | ||||||
|         if self.reset: |  | ||||||
|             self.mel_file_path.unlink(missing_ok=True) |  | ||||||
|         if not self.mel_file_path.exists(): |  | ||||||
|             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) |  | ||||||
|             lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) |  | ||||||
|             lock_file.touch(exist_ok=False) |  | ||||||
|  |  | ||||||
|             try: |  | ||||||
|                 audio_sample, sample_rate = torchaudio.load(self.audio_path) |  | ||||||
|             except RuntimeError: |  | ||||||
|                 import soundfile |  | ||||||
|  |  | ||||||
|                 data, samplerate = soundfile.read(self.audio_path) |  | ||||||
|                 # sf.available_formats() |  | ||||||
|                 # sf.available_subtypes() |  | ||||||
|                 soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32') |  | ||||||
|  |  | ||||||
|                 audio_sample, sample_rate = torchaudio.load(self.audio_path) |  | ||||||
|             if sample_rate != self.sampling_rate: |  | ||||||
|                 resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate)) |  | ||||||
|                 audio_sample = resample(audio_sample) |  | ||||||
|             if audio_sample.shape[0] > 1: |  | ||||||
|                 # Transform Stereo to Mono |  | ||||||
|                 audio_sample = audio_sample.mean(dim=0, keepdim=True) |  | ||||||
|             mel_sample = self._mel_transform(audio_sample) |  | ||||||
|             with self.mel_file_path.open('wb') as mel_file: |  | ||||||
|                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) |  | ||||||
|             lock_file.unlink() |  | ||||||
|         else: |  | ||||||
|             # print(f"Already existed.. Skipping {filename}") |  | ||||||
|             # mel_file = mel_file |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         # with mel_file.open(mode='rb') as f: |  | ||||||
|         #     mel_sample = pickle.load(f, fix_imports=True) |  | ||||||
|         return self.mel_file_path.exists() |  | ||||||
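Worth noting about the removed torchaudio variant: `sys.platform` is `'win32'` on Windows (never `'windows'`), so the `soundfile` branch above could not have run, and the module never imported `sys` in the first place. Had the backend selection been kept, the working form would have been roughly:

    import sys
    import torchaudio

    # 'win32' is the actual value on Windows, for both 32- and 64-bit Pythons
    torchaudio.set_audio_backend('soundfile' if sys.platform == 'win32' else 'sox_io')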
metrics/attention_rollout.py (new file, 47 lines)
							| @@ -0,0 +1,47 @@ | |||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from einops import reduce | ||||||
|  |  | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from sklearn.ensemble import IsolationForest | ||||||
|  | from sklearn.metrics import recall_score, roc_auc_score, average_precision_score | ||||||
|  |  | ||||||
|  | from ml_lib.metrics._base_score import _BaseScores | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AttentionRollout(_BaseScores): | ||||||
|  |  | ||||||
|  |     def __init__(self, *args): | ||||||
|  |         super(AttentionRollout, self).__init__(*args) | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def __call__(self, outputs): | ||||||
|  |         summary_dict = dict() | ||||||
|  |         ####################################################################################### | ||||||
|  |         # Additional Score  -  Histogram Distances - Image Plotting | ||||||
|  |         ####################################################################################### | ||||||
|  |         # | ||||||
|  |         # INIT | ||||||
|  |         attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs] | ||||||
|  |         attn_reduce_heads = [reduce(x, '') for x in attn_weights] | ||||||
|  |  | ||||||
|  |         if self.model.params.use_residual: | ||||||
|  |             residual_att = np.eye(att_mat.shape[1])[None, ...] | ||||||
|  |             aug_att_mat = att_mat + residual_att | ||||||
|  |             aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None] | ||||||
|  |         else: | ||||||
|  |             aug_att_mat = att_mat | ||||||
|  |  | ||||||
|  |         joint_attentions = np.zeros(aug_att_mat.shape) | ||||||
|  |  | ||||||
|  |         layers = joint_attentions.shape[0] | ||||||
|  |         joint_attentions[0] = aug_att_mat[0] | ||||||
|  |         for i in np.arange(1, layers): | ||||||
|  |             joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
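The new `AttentionRollout` hunk stops mid-computation: `att_mat` is never assigned (presumably the head-reduced `attn_weights`), the einops `reduce` pattern is empty, and nothing is returned or written to `summary_dict`. A self-contained sketch of the rollout recursion it appears to aim for (Abnar & Zuidema, 2020), with illustrative shapes:

    import numpy as np
    from einops import reduce

    def attention_rollout(attn_weights, use_residual=True):
        # attn_weights: (layers, heads, tokens, tokens)
        att_mat = reduce(attn_weights, 'l h t1 t2 -> l t1 t2', 'mean')  # average over heads
        if use_residual:
            att_mat = att_mat + np.eye(att_mat.shape[-1])[None, ...]    # model skip connections
            att_mat = att_mat / att_mat.sum(axis=-1, keepdims=True)     # re-normalize rows
        joint = att_mat[0]
        for layer_att in att_mat[1:]:
            joint = layer_att @ joint   # propagate attention through successive layers
        return joint                    # (tokens, tokens): influence of inputs on outputs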
| @@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores): | |||||||
|         ####################################################################################### |         ####################################################################################### | ||||||
|         # |         # | ||||||
|         # Confusion matrix |         # Confusion matrix | ||||||
|  |         fig1, ax1 = plt.subplots(dpi=96) | ||||||
|         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max], |         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max], | ||||||
|                               labels=[class_names[key] for key in class_names.keys()], |                               labels=[class_names[key] for key in class_names.keys()], | ||||||
|                               normalize='all') |                               normalize='all') | ||||||
|         disp = ConfusionMatrixDisplay(confusion_matrix=cm, |         disp = ConfusionMatrixDisplay(confusion_matrix=cm, | ||||||
|                                       display_labels=[class_names[i] for i in range(self.model.n_classes)] |                                       display_labels=[class_names[i] for i in range(self.model.n_classes)] | ||||||
|                                       ) |                                       ) | ||||||
|         disp.plot(include_values=True) |         disp.plot(include_values=True, ax=ax1) | ||||||
|  |  | ||||||
|         self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch) |         self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch) | ||||||
|         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf') |         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf') | ||||||
|  |  | ||||||
|         plt.close('all') |         plt.close('all') | ||||||
|         return summary_dict |         return summary_dict | ||||||
|   | |||||||
| @@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module): | |||||||
|  |  | ||||||
|         for attn, mlp in zip(self.attns, self.mlps): |         for attn, mlp in zip(self.attns, self.mlps): | ||||||
|             # Attention |             # Attention | ||||||
|             skip_connection = tensor.clone() |             attn_tensor = self.norm(tensor) | ||||||
|             tensor = self.norm(tensor) |  | ||||||
|             if return_attn_weights: |             if return_attn_weights: | ||||||
|                 tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights) |                 attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights) | ||||||
|                 attn_weights.append(attn_weight) |                 attn_weights.append(attn_weight) | ||||||
|             else: |             else: | ||||||
|                 tensor = attn(tensor, mask=mask) |                 attn_tensor = attn(attn_tensor, mask=mask) | ||||||
|             tensor = tensor + skip_connection |             tensor = attn_tensor + tensor | ||||||
|  |  | ||||||
|             # MLP |             # MLP | ||||||
|             skip_connection = tensor.clone() |             mlp_tensor = self.norm(tensor) | ||||||
|             tensor = self.norm(tensor) |             mlp_tensor = mlp(mlp_tensor) | ||||||
|             tensor = mlp(tensor) |             tensor = tensor + mlp_tensor | ||||||
|             tensor = tensor + skip_connection |  | ||||||
|  |  | ||||||
|         return (tensor, attn_weights) if return_attn_weights else tensor |         return (tensor, attn_weights) if return_attn_weights else tensor | ||||||
|   | |||||||
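The rewritten loop is the standard pre-norm residual pattern: normalize, apply the sublayer, then add the un-normalized input back. The old `.clone()` calls were also dead weight, since `tensor + skip_connection` already allocates a new tensor. Distilled into a standalone sketch (names mirror the module's own layers):

    def pre_norm_block(tensor, attn, mlp, norm, mask=None):
        tensor = tensor + attn(norm(tensor), mask=mask)  # attention sublayer + residual
        tensor = tensor + mlp(norm(tensor))              # feed-forward sublayer + residual
        return tensor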
| @@ -1,3 +1,6 @@ | |||||||
|  | import inspect | ||||||
|  | from argparse import ArgumentParser | ||||||
|  |  | ||||||
| from functools import reduce | from functools import reduce | ||||||
|  |  | ||||||
| from abc import ABC | from abc import ABC | ||||||
| @@ -5,13 +8,14 @@ from pathlib import Path | |||||||
|  |  | ||||||
| import torch | import torch | ||||||
| from operator import mul | from operator import mul | ||||||
|  | from pytorch_lightning.utilities import argparse_utils | ||||||
| from torch import nn | from torch import nn | ||||||
| from torch.nn import functional as F, Unfold | from torch.nn import functional as F, Unfold | ||||||
|  |  | ||||||
| # Utility - Modules | # Utility - Modules | ||||||
| ################### | ################### | ||||||
| from ..utils.model_io import ModelParameters | from ..utils.model_io import ModelParameters | ||||||
| from ..utils.tools import locate_and_import_class | from ..utils.tools import locate_and_import_class, add_argparse_args | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import pytorch_lightning as pl |     import pytorch_lightning as pl | ||||||
| @@ -32,14 +36,18 @@ try: | |||||||
|                 print(e) |                 print(e) | ||||||
|                 return -1 |                 return -1 | ||||||
|  |  | ||||||
|         def __init__(self, hparams): |         @classmethod | ||||||
|             super(LightningBaseModule, self).__init__() |         def from_argparse_args(cls, args, **kwargs): | ||||||
|  |             return argparse_utils.from_argparse_args(cls, args, **kwargs) | ||||||
|  |  | ||||||
|             # Set Parameters |         @classmethod | ||||||
|             ################################ |         def add_argparse_args(cls, parent_parser): | ||||||
|             self.hparams = hparams |             return add_argparse_args(cls, parent_parser) | ||||||
|             self.params = ModelParameters(hparams) |  | ||||||
|             self.lr = self.params.lr or 1e-4 |         def __init__(self, model_parameters, weight_init='xavier_normal_'): | ||||||
|  |             super(LightningBaseModule, self).__init__() | ||||||
|  |             self._weight_init = weight_init | ||||||
|  |             self.params = ModelParameters(model_parameters) | ||||||
|  |  | ||||||
|         def size(self): |         def size(self): | ||||||
|             return self.shape |             return self.shape | ||||||
| @@ -47,15 +55,6 @@ try: | |||||||
|         def additional_scores(self, outputs): |         def additional_scores(self, outputs): | ||||||
|             raise NotImplementedError |             raise NotImplementedError | ||||||
|  |  | ||||||
|         @property |  | ||||||
|         def dataset_class(self): |  | ||||||
|             try: |  | ||||||
|                 return locate_and_import_class(self.params.class_name, folder_path='datasets') |  | ||||||
|             except AttributeError as e: |  | ||||||
|                 raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' + |  | ||||||
|                                      f'was not found!\n' + |  | ||||||
|                                      f'{e}') |  | ||||||
|  |  | ||||||
|         def save_to_disk(self, model_path): |         def save_to_disk(self, model_path): | ||||||
|             Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True) |             Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True) | ||||||
|             if not (model_path / 'model_class.obj').exists(): |             if not (model_path / 'model_class.obj').exists(): | ||||||
| @@ -86,8 +85,12 @@ try: | |||||||
|         def test_epoch_end(self, outputs): |         def test_epoch_end(self, outputs): | ||||||
|             raise NotImplementedError |             raise NotImplementedError | ||||||
|  |  | ||||||
|         def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_): |         def init_weights(self): | ||||||
|             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_) |             if isinstance(self._weight_init, str): | ||||||
|  |                 mod = __import__('torch.nn.init', fromlist=[self._weight_init]) | ||||||
|  |                 self._weight_init = getattr(mod, self._weight_init) | ||||||
|  |             assert callable(self._weight_init) | ||||||
|  |             weight_initializer = WeightInit(in_place_init_function=self._weight_init) | ||||||
|             self.apply(weight_initializer) |             self.apply(weight_initializer) | ||||||
|  |  | ||||||
|     module_types = (LightningBaseModule, nn.Module,) |     module_types = (LightningBaseModule, nn.Module,) | ||||||
|   | |||||||
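`init_weights` now accepts the initializer as a string and resolves it against `torch.nn.init` at call time. A plain `getattr` does the same job as the `__import__` dance; a minimal equivalent sketch:

    import torch.nn.init

    def resolve_init(name_or_fn):
        # resolve a string like 'xavier_normal_' to the torch.nn.init function
        fn = getattr(torch.nn.init, name_or_fn) if isinstance(name_or_fn, str) else name_or_fn
        assert callable(fn)
        return fn

    resolve_init('xavier_normal_')  # -> torch.nn.init.xavier_normal_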
utils/_basedatamodule.py (new file, 29 lines)
							| @@ -0,0 +1,29 @@ | |||||||
|  | from pytorch_lightning import LightningDataModule | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Dataset Options | ||||||
|  | from ml_lib.utils.tools import add_argparse_args | ||||||
|  |  | ||||||
|  | DATA_OPTION_test = 'test' | ||||||
|  | DATA_OPTION_devel = 'devel' | ||||||
|  | DATA_OPTION_train = 'train' | ||||||
|  | DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _BaseDataModule(LightningDataModule): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def shape(self): | ||||||
|  |         return self.datasets[DATA_OPTION_train].sample_shape | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def add_argparse_args(cls, parent_parser): | ||||||
|  |         return add_argparse_args(cls, parent_parser) | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.datasets = dict() | ||||||
|  |  | ||||||
|  |     def transfer_batch_to_device(self, batch, device): | ||||||
|  |         return batch.to(device) | ||||||
|  |  | ||||||
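A subclass only needs to populate `self.datasets` under the `DATA_OPTION_*` keys for `shape` and the Lightning hooks to work. A minimal sketch; `build_dataset` is a placeholder for whatever constructs a split:

    class MelDataModule(_BaseDataModule):
        def __init__(self, data_root, batch_size=16, **kwargs):
            super().__init__()
            self.data_root, self.batch_size = data_root, batch_size

        def setup(self, stage=None):
            # populate all three splits; dataset construction is hypothetical
            for option in DATA_OPTIONS:
                self.datasets[option] = build_dataset(self.data_root, option)

Note that `transfer_batch_to_device` as written assumes the batch object itself implements `.to(device)`, so tuple-shaped batches would need an override.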
| @@ -1,19 +1,34 @@ | |||||||
| from abc import ABC | import inspect | ||||||
|  | from argparse import ArgumentParser | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | import os | ||||||
| from pytorch_lightning.loggers.base import LightningLoggerBase | from pytorch_lightning.loggers.base import LightningLoggerBase | ||||||
| from pytorch_lightning.loggers.neptune import NeptuneLogger | from pytorch_lightning.loggers.neptune import NeptuneLogger | ||||||
| from neptune.api_exceptions import ProjectNotFound | from neptune.api_exceptions import ProjectNotFound | ||||||
| # noinspection PyUnresolvedReferences |  | ||||||
| from pytorch_lightning.loggers.csv_logs import CSVLogger | from pytorch_lightning.loggers.csv_logs import CSVLogger | ||||||
|  | from pytorch_lightning.utilities import argparse_utils | ||||||
|  |  | ||||||
| from .config import Config | from ml_lib.utils.tools import add_argparse_args | ||||||
|  |  | ||||||
|  |  | ||||||
| class Logger(LightningLoggerBase, ABC): | class Logger(LightningLoggerBase): | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_argparse_args(cls, args, **kwargs): | ||||||
|  |         return argparse_utils.from_argparse_args(cls, args, **kwargs) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def name(self) -> str: | ||||||
|  |         return self._name | ||||||
|  |  | ||||||
|     media_dir = 'media' |     media_dir = 'media' | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def add_argparse_args(cls, parent_parser): | ||||||
|  |         return add_argparse_args(cls, parent_parser) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def experiment(self): |     def experiment(self): | ||||||
|         if self.debug: |         if self.debug: | ||||||
| @@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC): | |||||||
|     def log_dir(self): |     def log_dir(self): | ||||||
|         return Path(self.csvlogger.experiment.log_dir) |         return Path(self.csvlogger.experiment.log_dir) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def name(self): |  | ||||||
|         return self.config.name |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def project_name(self): |     def project_name(self): | ||||||
|         return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}" |         return f"{self.owner}/{self.name.replace('_', '-')}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def version(self): |     def version(self): | ||||||
|         return self.config.get('main', 'seed') |         return self.seed | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def save_dir(self): | ||||||
|  |         return self.log_dir | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def outpath(self): |     def outpath(self): | ||||||
|         return Path(self.config.train.outpath) / self.config.model.type |         return Path(self.root_out) / self.model_name | ||||||
|  |  | ||||||
|     @property |     def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False): | ||||||
|     def exp_path(self): |  | ||||||
|         return Path(self.outpath) / self.name |  | ||||||
|  |  | ||||||
|     def __init__(self, config: Config): |  | ||||||
|         """ |         """ | ||||||
|         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only. |         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only. | ||||||
|            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be |            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be | ||||||
| @@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC): | |||||||
|         """ |         """ | ||||||
|         super(Logger, self).__init__() |         super(Logger, self).__init__() | ||||||
|  |  | ||||||
|         self.config = config |         self.debug = debug | ||||||
|         self.debug = self.config.main.debug |         self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test' | ||||||
|         if self.debug: |         self.owner = owner if not self.debug else 'testuser' | ||||||
|             self.config.add_section('project') |         self.neptune_key = neptune_key if not self.debug else 'XXX' | ||||||
|             self.config.set('project', 'owner', 'testuser') |         self.root_out = outpath if not self.debug else 'debug_out' | ||||||
|             self.config.set('project', 'name', 'test') |         self.seed = seed | ||||||
|             self.config.set('project', 'neptune_key', 'XXX') |         self.model_name = model_name | ||||||
|  |  | ||||||
|         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) |         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) | ||||||
|         self._neptune_kwargs = dict(offline_mode=self.debug, |         self._neptune_kwargs = dict(offline_mode=self.debug, | ||||||
|                                     api_key=self.config.project.neptune_key, |                                     api_key=self.neptune_key, | ||||||
|                                     experiment_name=self.name, |                                     experiment_name=self.name, | ||||||
|                                     project_name=self.project_name, |                                     project_name=self.project_name) | ||||||
|                                     params=self.config.model_paramters) |  | ||||||
|         try: |         try: | ||||||
|             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) |             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) | ||||||
|         except ProjectNotFound as e: |         except ProjectNotFound as e: | ||||||
| @@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC): | |||||||
|             print(e) |             print(e) | ||||||
|  |  | ||||||
|         self.csvlogger = CSVLogger(**self._csvlogger_kwargs) |         self.csvlogger = CSVLogger(**self._csvlogger_kwargs) | ||||||
|         self.log_config_as_ini() |  | ||||||
|  |  | ||||||
|     def log_hyperparams(self, params): |     def log_hyperparams(self, params): | ||||||
|         self.neptunelogger.log_hyperparams(params) |         self.neptunelogger.log_hyperparams(params) | ||||||
| @@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC): | |||||||
|         self.csvlogger.close() |         self.csvlogger.close() | ||||||
|         self.neptunelogger.close() |         self.neptunelogger.close() | ||||||
|  |  | ||||||
|     def log_config_as_ini(self): |     def log_text(self, name, text, **_): | ||||||
|         self.config.write(self.log_dir / 'config.ini') |  | ||||||
|  |  | ||||||
|     def log_text(self, name, text, step_nb=0, **_): |  | ||||||
|         # TODO Implement Offline variant. |         # TODO Implement Offline variant. | ||||||
|         self.neptunelogger.log_text(name, text, step_nb) |         self.neptunelogger.log_text(name, text) | ||||||
|  |  | ||||||
|     def log_metric(self, metric_name, metric_value, **kwargs): |     def log_metric(self, metric_name, metric_value, **kwargs): | ||||||
|         self.csvlogger.log_metrics(dict(metric_name=metric_value)) |         self.csvlogger.log_metrics(dict(metric_name=metric_value)) | ||||||
|         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) |         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) | ||||||
|  |  | ||||||
|     def log_image(self, name, image, ext='png', **kwargs): |     def log_image(self, name, image, ext='png', **kwargs): | ||||||
|  |  | ||||||
|         step = kwargs.get('step', None) |         step = kwargs.get('step', None) | ||||||
|         image_name = f'{step}_{name}' if step is not None else name |         image_name = f'{step}_{name}' if step is not None else name | ||||||
|         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}' |         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}' | ||||||
|   | |||||||
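After this refactor the `Logger` is configured entirely through constructor arguments (or `from_argparse_args`) instead of a `Config` object; in debug mode it silently swaps in offline placeholders. An illustrative construction, all values placeholders:

    logger = Logger(owner='my-neptune-user', neptune_key='<API-TOKEN>',
                    model_name='Transformer', project_name='my-project',
                    outpath='output', seed=69, debug=False)

    # or, from a parsed argparse.Namespace carrying the same fields:
    # logger = Logger.from_argparse_args(args)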
| @@ -13,6 +13,10 @@ from torch import nn | |||||||
| # Hyperparamter Object | # Hyperparamter Object | ||||||
| class ModelParameters(Namespace, Mapping): | class ModelParameters(Namespace, Mapping): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def as_dict(self): | ||||||
|  |         return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()} | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def activation_as_string(self): |     def activation_as_string(self): | ||||||
|         return self['activation'].lower() |         return self['activation'].lower() | ||||||
| @@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping): | |||||||
|         if name == 'activation': |         if name == 'activation': | ||||||
|             return self._activations[self['activation'].lower()] |             return self._activations[self['activation'].lower()] | ||||||
|         else: |         else: | ||||||
|             try: |             return super(ModelParameters, self).__getattribute__(name) | ||||||
|                 return super(ModelParameters, self).__getattribute__(name) |  | ||||||
|             except AttributeError as e: |  | ||||||
|                 if name == 'stretch': |  | ||||||
|                     return False |  | ||||||
|                 else: |  | ||||||
|                     return None |  | ||||||
|  |  | ||||||
|     _activations = dict( |     _activations = dict( | ||||||
|         leaky_relu=nn.LeakyReLU, |         leaky_relu=nn.LeakyReLU, | ||||||
| @@ -88,16 +86,20 @@ class SavedLightningModels(object): | |||||||
|             model = torch.load(models_root_path / 'model_class.obj') |             model = torch.load(models_root_path / 'model_class.obj') | ||||||
|         assert model is not None |         assert model is not None | ||||||
|  |  | ||||||
|         return cls(weights=str(checkpoint_path), model=model) |         return cls(weights=checkpoint_path, model=model) | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
|         self.weights: str = kwargs.get('weights', '') |         self.weights: Path = Path(kwargs.get('weights', '')) | ||||||
|  |         self.hparams: Path = self.weights.parent / 'hparams.yaml' | ||||||
|  |  | ||||||
|         self.model = kwargs.get('model', None) |         self.model = kwargs.get('model', None) | ||||||
|         assert self.model is not None |         assert self.model is not None | ||||||
|  |  | ||||||
|     def restore(self): |     def restore(self): | ||||||
|         pretrained_model = self.model.load_from_checkpoint(self.weights) |  | ||||||
|  |         pretrained_model = self.model.load_from_checkpoint(self.weights.__str__()) | ||||||
|  |         # , hparams_file=self.hparams.__str__()) | ||||||
|         pretrained_model.eval() |         pretrained_model.eval() | ||||||
|         pretrained_model.freeze() |         pretrained_model.freeze() | ||||||
|         return pretrained_model |         return pretrained_model | ||||||
|  |  | ||||||
|   | |||||||
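The removed `__getattribute__` fallback is a behavior change: missing hyperparameters now raise `AttributeError` instead of silently returning `None` (or `False` for `stretch`), so call sites must either provide every key or use `.get()`. Illustrative, assuming `params` is a `ModelParameters` whose hparams define `lr` but not `stretch`:

    params.lr        # -> 1e-3
    params.stretch   # now raises AttributeError; previously returned False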
| @@ -1,6 +1,9 @@ | |||||||
| import importlib | import importlib | ||||||
|  | import inspect | ||||||
| import pickle | import pickle | ||||||
| import shelve | import shelve | ||||||
|  | from argparse import ArgumentParser | ||||||
|  | from ast import literal_eval | ||||||
| from pathlib import Path, PurePath | from pathlib import Path, PurePath | ||||||
| from typing import Union | from typing import Union | ||||||
|  |  | ||||||
| @@ -9,6 +12,13 @@ import torch | |||||||
| import random | import random | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def auto_cast(a): | ||||||
|  |   try: | ||||||
|  |     return literal_eval(a) | ||||||
|  |   except: | ||||||
|  |     return a | ||||||
|  |  | ||||||
|  |  | ||||||
| def to_one_hot(idx_array, max_classes): | def to_one_hot(idx_array, max_classes): | ||||||
|     one_hot = np.zeros((idx_array.size, max_classes)) |     one_hot = np.zeros((idx_array.size, max_classes)) | ||||||
|     one_hot[np.arange(idx_array.size), idx_array] = 1 |     one_hot[np.arange(idx_array.size), idx_array] = 1 | ||||||
| @@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): | |||||||
|            continue |            continue | ||||||
|     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}') |     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def add_argparse_args(cls, parent_parser): | ||||||
|  |     parser = ArgumentParser(parents=[parent_parser], add_help=False) | ||||||
|  |     full_arg_spec = inspect.getfullargspec(cls.__init__) | ||||||
|  |     n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0) | ||||||
|  |     for idx, argument in enumerate(full_arg_spec.args): | ||||||
|  |         if argument == 'self': | ||||||
|  |             continue | ||||||
|  |         if idx < n_non_defaults: | ||||||
|  |             parser.add_argument(f'--{argument}', type=int) | ||||||
|  |         else: | ||||||
|  |             argument_type = type(argument) | ||||||
|  |             parser.add_argument(f'--{argument}', | ||||||
|  |                                 type=argument_type, | ||||||
|  |                                 default=full_arg_spec.defaults[idx - n_non_defaults] | ||||||
|  |                                 ) | ||||||
|  |     return parser | ||||||
|   | |||||||
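One caveat in the new `add_argparse_args`: `type(argument)` is always `str`, because `argument` is the parameter *name* from `getfullargspec`, not its value. Inferring the type from the default, falling back to the `auto_cast` helper defined above when the default is `None`, is presumably what was meant; a hedged sketch of that else-branch:

    default = full_arg_spec.defaults[idx - n_non_defaults]
    parser.add_argument(f'--{argument}',
                        # note: bool defaults would still need special handling,
                        # since argparse's type=bool is truthy for any non-empty string
                        type=type(default) if default is not None else auto_cast,
                        default=default)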
Author: Steffen Illium