bringing brances up to date
This commit is contained in:
		| @@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage | ||||
| from ml_lib.audio_toolset.mel_dataset import TorchMelDataset | ||||
|  | ||||
|  | ||||
| class _AudioToMelDataset(Dataset, ABC): | ||||
| import librosa | ||||
|  | ||||
|  | ||||
| class LibrosaAudioToMelDataset(Dataset): | ||||
|  | ||||
|     @property | ||||
|     def audio_file_duration(self): | ||||
|         raise NotImplementedError | ||||
|         return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path) | ||||
|  | ||||
|     @property | ||||
|     def sampling_rate(self): | ||||
|         raise NotImplementedError | ||||
|         return self.mel_kwargs.get('sr', None) | ||||
|  | ||||
|     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False, | ||||
|                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs): | ||||
|         self.ignored_kwargs = kwargs | ||||
|         super(LibrosaAudioToMelDataset, self).__init__() | ||||
|  | ||||
|         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) | ||||
|         mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path)) | ||||
|         self.mel_kwargs = mel_kwargs | ||||
|         self.reset = reset | ||||
|         self.audio_path = Path(audio_file_path) | ||||
|  | ||||
|         mel_folder_suffix = self.audio_path.parent.parent.name | ||||
|  | ||||
|         self.mel_file_path = Path(str(self.audio_path) | ||||
|                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder') | ||||
|                                   .replace(self.audio_path.suffix, '.npy')) | ||||
| @@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC): | ||||
|                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'], | ||||
|                                        mel_kwargs['n_mels'], transform=mel_augmentations) | ||||
|  | ||||
|     def _build_mel(self): | ||||
|         raise NotImplementedError | ||||
|         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), | ||||
|                                        MelToImage() | ||||
|                                        ]) | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         try: | ||||
|             return self.dataset[item] | ||||
|         except FileNotFoundError: | ||||
|             assert self._build_mel() | ||||
|         return self.dataset[item] | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.dataset) | ||||
|  | ||||
|  | ||||
| import librosa | ||||
|  | ||||
|  | ||||
| class LibrosaAudioToMelDataset(_AudioToMelDataset): | ||||
|  | ||||
|     @property | ||||
|     def audio_file_duration(self): | ||||
|         return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path) | ||||
|  | ||||
|     @property | ||||
|     def sampling_rate(self): | ||||
|         return self.mel_kwargs.get('sr', None) | ||||
|  | ||||
|     def __init__(self, audio_file_path, *args, **kwargs): | ||||
|  | ||||
|         audio_file_path = Path(audio_file_path) | ||||
|         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) | ||||
|         mel_kwargs = kwargs.get('mel_kwargs', dict()) | ||||
|         mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path)) | ||||
|         kwargs.update(mel_kwargs=mel_kwargs) | ||||
|  | ||||
|         super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) | ||||
|  | ||||
|         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), | ||||
|                                        MelToImage() | ||||
|                                        ]) | ||||
|  | ||||
|     def _build_mel(self): | ||||
|     def build_mel(self): | ||||
|         if self.reset: | ||||
|             self.mel_file_path.unlink(missing_ok=True) | ||||
|         if not self.mel_file_path.exists(): | ||||
|             lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) | ||||
|             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|             lockfile.touch(exist_ok=False) | ||||
|             raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate) | ||||
|             mel_sample = self._mel_transform(raw_sample) | ||||
|             with self.mel_file_path.open('wb') as mel_file: | ||||
|                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) | ||||
|             lockfile.unlink(missing_ok=False) | ||||
|         else: | ||||
|             pass | ||||
|  | ||||
|   | ||||
| @@ -11,13 +11,16 @@ class TorchMelDataset(Dataset): | ||||
|     def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len, | ||||
|                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True): | ||||
|         super(TorchMelDataset, self).__init__() | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.audio_file_len = audio_file_len | ||||
|         self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None | ||||
|         self.sampling_rate = int(sampling_rate) | ||||
|         self.audio_file_len = int(audio_file_len) | ||||
|         if auto_pad_to_shape and sub_segment_len: | ||||
|             self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len))) | ||||
|         else: | ||||
|             self.padding = None | ||||
|         self.path = Path(mel_path) | ||||
|         self.sub_segment_len = sub_segment_len | ||||
|         self.mel_hop_len = mel_hop_len | ||||
|         self.sub_segment_hop_len = sub_segment_hop_len | ||||
|         self.sub_segment_len = int(sub_segment_len) | ||||
|         self.mel_hop_len = int(mel_hop_len) | ||||
|         self.sub_segment_hop_len = int(sub_segment_hop_len) | ||||
|         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1) | ||||
|         if self.sub_segment_len and self.sub_segment_hop_len: | ||||
|             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) | ||||
| @@ -27,8 +30,6 @@ class TorchMelDataset(Dataset): | ||||
|         self.transform = transform | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         while Path(str(self.path).replace(self.path.suffix, '.lock')).exists(): | ||||
|             time.sleep(0.01) | ||||
|         with self.path.open('rb') as mel_file: | ||||
|             mel_spec = pickle.load(mel_file, fix_imports=True) | ||||
|         start = self.offsets[item] | ||||
| @@ -38,7 +39,7 @@ class TorchMelDataset(Dataset): | ||||
|             snippet = self.transform(snippet) | ||||
|         if self.padding: | ||||
|             snippet = self.padding(snippet) | ||||
|         return snippet, self.label | ||||
|         return self.path.__str__(), snippet, self.label | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.offsets) | ||||
|   | ||||
| @@ -1,67 +0,0 @@ | ||||
|  | ||||
| import torchaudio | ||||
| if sys.platform =='windows': | ||||
|     torchaudio.set_audio_backend('soundfile') | ||||
| else: | ||||
|     torchaudio.set_audio_backend('sox_io') | ||||
|  | ||||
|  | ||||
| class PyTorchAudioToMelDataset(_AudioToMelDataset): | ||||
|  | ||||
|     @property | ||||
|     def audio_file_duration(self): | ||||
|         info_obj = torchaudio.info(self.audio_path) | ||||
|         return info_obj.num_frames / info_obj.sample_rate | ||||
|  | ||||
|     @property | ||||
|     def sampling_rate(self): | ||||
|         return self.mel_kwargs['sample_rate'] | ||||
|  | ||||
|     def __init__(self, audio_file_path, *args, **kwargs): | ||||
|         super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) | ||||
|  | ||||
|         audio_file_path = Path(audio_file_path) | ||||
|         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) | ||||
|  | ||||
|         from torchaudio.transforms import MelSpectrogram | ||||
|         self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs), | ||||
|                                        MelToImage() | ||||
|                                        ]) | ||||
|  | ||||
|     def _build_mel(self): | ||||
|         if self.reset: | ||||
|             self.mel_file_path.unlink(missing_ok=True) | ||||
|         if not self.mel_file_path.exists(): | ||||
|             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|             lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) | ||||
|             lock_file.touch(exist_ok=False) | ||||
|  | ||||
|             try: | ||||
|                 audio_sample, sample_rate = torchaudio.load(self.audio_path) | ||||
|             except RuntimeError: | ||||
|                 import soundfile | ||||
|  | ||||
|                 data, samplerate = soundfile.read(self.audio_path) | ||||
|                 # sf.available_formats() | ||||
|                 # sf.available_subtypes() | ||||
|                 soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32') | ||||
|  | ||||
|                 audio_sample, sample_rate = torchaudio.load(self.audio_path) | ||||
|             if sample_rate != self.sampling_rate: | ||||
|                 resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate)) | ||||
|                 audio_sample = resample(audio_sample) | ||||
|             if audio_sample.shape[0] > 1: | ||||
|                 # Transform Stereo to Mono | ||||
|                 audio_sample = audio_sample.mean(dim=0, keepdim=True) | ||||
|             mel_sample = self._mel_transform(audio_sample) | ||||
|             with self.mel_file_path.open('wb') as mel_file: | ||||
|                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) | ||||
|             lock_file.unlink() | ||||
|         else: | ||||
|             # print(f"Already existed.. Skipping {filename}") | ||||
|             # mel_file = mel_file | ||||
|             pass | ||||
|  | ||||
|         # with mel_file.open(mode='rb') as f: | ||||
|         #     mel_sample = pickle.load(f, fix_imports=True) | ||||
|         return self.mel_file_path.exists() | ||||
							
								
								
									
										47
									
								
								metrics/attention_rollout.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								metrics/attention_rollout.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| import numpy as np | ||||
|  | ||||
| from einops import reduce | ||||
|  | ||||
|  | ||||
| import torch | ||||
| from sklearn.ensemble import IsolationForest | ||||
| from sklearn.metrics import recall_score, roc_auc_score, average_precision_score | ||||
|  | ||||
| from ml_lib.metrics._base_score import _BaseScores | ||||
|  | ||||
|  | ||||
| class AttentionRollout(_BaseScores): | ||||
|  | ||||
|     def __init__(self, *args): | ||||
|         super(AttentionRollout, self).__init__(*args) | ||||
|         pass | ||||
|  | ||||
|     def __call__(self, outputs): | ||||
|         summary_dict = dict() | ||||
|         ####################################################################################### | ||||
|         # Additional Score  -  Histogram Distances - Image Plotting | ||||
|         ####################################################################################### | ||||
|         # | ||||
|         # INIT | ||||
|         attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs] | ||||
|         attn_reduce_heads = [reduce(x, '') for x in attn_weights] | ||||
|  | ||||
|         if self.model.params.use_residual: | ||||
|             residual_att = np.eye(att_mat.shape[1])[None, ...] | ||||
|             aug_att_mat = att_mat + residual_att | ||||
|             aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None] | ||||
|         else: | ||||
|             aug_att_mat = att_mat | ||||
|  | ||||
|         joint_attentions = np.zeros(aug_att_mat.shape) | ||||
|  | ||||
|         layers = joint_attentions.shape[0] | ||||
|         joint_attentions[0] = aug_att_mat[0] | ||||
|         for i in np.arange(1, layers): | ||||
|             joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1]) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -113,16 +113,16 @@ class MultiClassScores(_BaseScores): | ||||
|         ####################################################################################### | ||||
|         # | ||||
|         # Confusion matrix | ||||
|  | ||||
|         fig1, ax1 = plt.subplots(dpi=96) | ||||
|         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max], | ||||
|                               labels=[class_names[key] for key in class_names.keys()], | ||||
|                               normalize='all') | ||||
|         disp = ConfusionMatrixDisplay(confusion_matrix=cm, | ||||
|                                       display_labels=[class_names[i] for i in range(self.model.n_classes)] | ||||
|                                       ) | ||||
|         disp.plot(include_values=True) | ||||
|         disp.plot(include_values=True, ax=ax1) | ||||
|  | ||||
|         self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch) | ||||
|         self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch) | ||||
|         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf') | ||||
|  | ||||
|         plt.close('all') | ||||
|   | ||||
| @@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module): | ||||
|  | ||||
|         for attn, mlp in zip(self.attns, self.mlps): | ||||
|             # Attention | ||||
|             skip_connection = tensor.clone() | ||||
|             tensor = self.norm(tensor) | ||||
|             attn_tensor = self.norm(tensor) | ||||
|             if return_attn_weights: | ||||
|                 tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights) | ||||
|                 attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights) | ||||
|                 attn_weights.append(attn_weight) | ||||
|             else: | ||||
|                 tensor = attn(tensor, mask=mask) | ||||
|             tensor = tensor + skip_connection | ||||
|                 attn_tensor = attn(attn_tensor, mask=mask) | ||||
|             tensor = attn_tensor + tensor | ||||
|  | ||||
|             # MLP | ||||
|             skip_connection = tensor.clone() | ||||
|             tensor = self.norm(tensor) | ||||
|             tensor = mlp(tensor) | ||||
|             tensor = tensor + skip_connection | ||||
|             mlp_tensor = self.norm(tensor) | ||||
|             mlp_tensor = mlp(mlp_tensor) | ||||
|             tensor = tensor + mlp_tensor | ||||
|  | ||||
|         return (tensor, attn_weights) if return_attn_weights else tensor | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| import inspect | ||||
| from argparse import ArgumentParser | ||||
|  | ||||
| from functools import reduce | ||||
|  | ||||
| from abc import ABC | ||||
| @@ -5,13 +8,14 @@ from pathlib import Path | ||||
|  | ||||
| import torch | ||||
| from operator import mul | ||||
| from pytorch_lightning.utilities import argparse_utils | ||||
| from torch import nn | ||||
| from torch.nn import functional as F, Unfold | ||||
|  | ||||
| # Utility - Modules | ||||
| ################### | ||||
| from ..utils.model_io import ModelParameters | ||||
| from ..utils.tools import locate_and_import_class | ||||
| from ..utils.tools import locate_and_import_class, add_argparse_args | ||||
|  | ||||
| try: | ||||
|     import pytorch_lightning as pl | ||||
| @@ -32,14 +36,18 @@ try: | ||||
|                 print(e) | ||||
|                 return -1 | ||||
|  | ||||
|         def __init__(self, hparams): | ||||
|             super(LightningBaseModule, self).__init__() | ||||
|         @classmethod | ||||
|         def from_argparse_args(cls, args, **kwargs): | ||||
|             return argparse_utils.from_argparse_args(cls, args, **kwargs) | ||||
|  | ||||
|             # Set Parameters | ||||
|             ################################ | ||||
|             self.hparams = hparams | ||||
|             self.params = ModelParameters(hparams) | ||||
|             self.lr = self.params.lr or 1e-4 | ||||
|         @classmethod | ||||
|         def add_argparse_args(cls, parent_parser): | ||||
|             return add_argparse_args(cls, parent_parser) | ||||
|  | ||||
|         def __init__(self, model_parameters, weight_init='xavier_normal_'): | ||||
|             super(LightningBaseModule, self).__init__() | ||||
|             self._weight_init = weight_init | ||||
|             self.params = ModelParameters(model_parameters) | ||||
|  | ||||
|         def size(self): | ||||
|             return self.shape | ||||
| @@ -47,15 +55,6 @@ try: | ||||
|         def additional_scores(self, outputs): | ||||
|             raise NotImplementedError | ||||
|  | ||||
|         @property | ||||
|         def dataset_class(self): | ||||
|             try: | ||||
|                 return locate_and_import_class(self.params.class_name, folder_path='datasets') | ||||
|             except AttributeError as e: | ||||
|                 raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' + | ||||
|                                      f'was not found!\n' + | ||||
|                                      f'{e}') | ||||
|  | ||||
|         def save_to_disk(self, model_path): | ||||
|             Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True) | ||||
|             if not (model_path / 'model_class.obj').exists(): | ||||
| @@ -86,8 +85,12 @@ try: | ||||
|         def test_epoch_end(self, outputs): | ||||
|             raise NotImplementedError | ||||
|  | ||||
|         def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_): | ||||
|             weight_initializer = WeightInit(in_place_init_function=in_place_init_func_) | ||||
|         def init_weights(self): | ||||
|             if isinstance(self._weight_init, str): | ||||
|                 mod = __import__('torch.nn.init', fromlist=[self._weight_init]) | ||||
|                 self._weight_init = getattr(mod, self._weight_init) | ||||
|             assert callable(self._weight_init) | ||||
|             weight_initializer = WeightInit(in_place_init_function=self._weight_init) | ||||
|             self.apply(weight_initializer) | ||||
|  | ||||
|     module_types = (LightningBaseModule, nn.Module,) | ||||
|   | ||||
							
								
								
									
										29
									
								
								utils/_basedatamodule.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								utils/_basedatamodule.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| from pytorch_lightning import LightningDataModule | ||||
|  | ||||
|  | ||||
| # Dataset Options | ||||
| from ml_lib.utils.tools import add_argparse_args | ||||
|  | ||||
| DATA_OPTION_test = 'test' | ||||
| DATA_OPTION_devel = 'devel' | ||||
| DATA_OPTION_train = 'train' | ||||
| DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test] | ||||
|  | ||||
|  | ||||
| class _BaseDataModule(LightningDataModule): | ||||
|  | ||||
|     @property | ||||
|     def shape(self): | ||||
|         return self.datasets[DATA_OPTION_train].sample_shape | ||||
|  | ||||
|     @classmethod | ||||
|     def add_argparse_args(cls, parent_parser): | ||||
|         return add_argparse_args(cls, parent_parser) | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.datasets = dict() | ||||
|  | ||||
|     def transfer_batch_to_device(self, batch, device): | ||||
|         return batch.to(device) | ||||
|  | ||||
| @@ -1,19 +1,34 @@ | ||||
| from abc import ABC | ||||
| import inspect | ||||
| from argparse import ArgumentParser | ||||
| from pathlib import Path | ||||
|  | ||||
| import os | ||||
| from pytorch_lightning.loggers.base import LightningLoggerBase | ||||
| from pytorch_lightning.loggers.neptune import NeptuneLogger | ||||
| from neptune.api_exceptions import ProjectNotFound | ||||
| # noinspection PyUnresolvedReferences | ||||
|  | ||||
| from pytorch_lightning.loggers.csv_logs import CSVLogger | ||||
| from pytorch_lightning.utilities import argparse_utils | ||||
|  | ||||
| from .config import Config | ||||
| from ml_lib.utils.tools import add_argparse_args | ||||
|  | ||||
|  | ||||
| class Logger(LightningLoggerBase, ABC): | ||||
| class Logger(LightningLoggerBase): | ||||
|  | ||||
|     @classmethod | ||||
|     def from_argparse_args(cls, args, **kwargs): | ||||
|         return argparse_utils.from_argparse_args(cls, args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def name(self) -> str: | ||||
|         return self._name | ||||
|  | ||||
|     media_dir = 'media' | ||||
|  | ||||
|     @classmethod | ||||
|     def add_argparse_args(cls, parent_parser): | ||||
|         return add_argparse_args(cls, parent_parser) | ||||
|  | ||||
|     @property | ||||
|     def experiment(self): | ||||
|         if self.debug: | ||||
| @@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC): | ||||
|     def log_dir(self): | ||||
|         return Path(self.csvlogger.experiment.log_dir) | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self.config.name | ||||
|  | ||||
|     @property | ||||
|     def project_name(self): | ||||
|         return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}" | ||||
|         return f"{self.owner}/{self.name.replace('_', '-')}" | ||||
|  | ||||
|     @property | ||||
|     def version(self): | ||||
|         return self.config.get('main', 'seed') | ||||
|         return self.seed | ||||
|  | ||||
|     @property | ||||
|     def save_dir(self): | ||||
|         return self.log_dir | ||||
|  | ||||
|     @property | ||||
|     def outpath(self): | ||||
|         return Path(self.config.train.outpath) / self.config.model.type | ||||
|         return Path(self.root_out) / self.model_name | ||||
|  | ||||
|     @property | ||||
|     def exp_path(self): | ||||
|         return Path(self.outpath) / self.name | ||||
|  | ||||
|     def __init__(self, config: Config): | ||||
|     def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False): | ||||
|         """ | ||||
|         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only. | ||||
|            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be | ||||
| @@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC): | ||||
|         """ | ||||
|         super(Logger, self).__init__() | ||||
|  | ||||
|         self.config = config | ||||
|         self.debug = self.config.main.debug | ||||
|         if self.debug: | ||||
|             self.config.add_section('project') | ||||
|             self.config.set('project', 'owner', 'testuser') | ||||
|             self.config.set('project', 'name', 'test') | ||||
|             self.config.set('project', 'neptune_key', 'XXX') | ||||
|         self.debug = debug | ||||
|         self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test' | ||||
|         self.owner = owner if not self.debug else 'testuser' | ||||
|         self.neptune_key = neptune_key if not self.debug else 'XXX' | ||||
|         self.root_out = outpath if not self.debug else 'debug_out' | ||||
|         self.seed = seed | ||||
|         self.model_name = model_name | ||||
|  | ||||
|         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) | ||||
|         self._neptune_kwargs = dict(offline_mode=self.debug, | ||||
|                                     api_key=self.config.project.neptune_key, | ||||
|                                     api_key=self.neptune_key, | ||||
|                                     experiment_name=self.name, | ||||
|                                     project_name=self.project_name, | ||||
|                                     params=self.config.model_paramters) | ||||
|                                     project_name=self.project_name) | ||||
|         try: | ||||
|             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) | ||||
|         except ProjectNotFound as e: | ||||
| @@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC): | ||||
|             print(e) | ||||
|  | ||||
|         self.csvlogger = CSVLogger(**self._csvlogger_kwargs) | ||||
|         self.log_config_as_ini() | ||||
|  | ||||
|     def log_hyperparams(self, params): | ||||
|         self.neptunelogger.log_hyperparams(params) | ||||
| @@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC): | ||||
|         self.csvlogger.close() | ||||
|         self.neptunelogger.close() | ||||
|  | ||||
|     def log_config_as_ini(self): | ||||
|         self.config.write(self.log_dir / 'config.ini') | ||||
|  | ||||
|     def log_text(self, name, text, step_nb=0, **_): | ||||
|     def log_text(self, name, text, **_): | ||||
|         # TODO Implement Offline variant. | ||||
|         self.neptunelogger.log_text(name, text, step_nb) | ||||
|         self.neptunelogger.log_text(name, text) | ||||
|  | ||||
|     def log_metric(self, metric_name, metric_value, **kwargs): | ||||
|         self.csvlogger.log_metrics(dict(metric_name=metric_value)) | ||||
|         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) | ||||
|  | ||||
|     def log_image(self, name, image, ext='png', **kwargs): | ||||
|  | ||||
|         step = kwargs.get('step', None) | ||||
|         image_name = f'{step}_{name}' if step is not None else name | ||||
|         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}' | ||||
|   | ||||
| @@ -13,6 +13,10 @@ from torch import nn | ||||
| # Hyperparamter Object | ||||
| class ModelParameters(Namespace, Mapping): | ||||
|  | ||||
|     @property | ||||
|     def as_dict(self): | ||||
|         return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()} | ||||
|  | ||||
|     @property | ||||
|     def activation_as_string(self): | ||||
|         return self['activation'].lower() | ||||
| @@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping): | ||||
|         if name == 'activation': | ||||
|             return self._activations[self['activation'].lower()] | ||||
|         else: | ||||
|             try: | ||||
|             return super(ModelParameters, self).__getattribute__(name) | ||||
|             except AttributeError as e: | ||||
|                 if name == 'stretch': | ||||
|                     return False | ||||
|                 else: | ||||
|                     return None | ||||
|  | ||||
|     _activations = dict( | ||||
|         leaky_relu=nn.LeakyReLU, | ||||
| @@ -88,16 +86,20 @@ class SavedLightningModels(object): | ||||
|             model = torch.load(models_root_path / 'model_class.obj') | ||||
|         assert model is not None | ||||
|  | ||||
|         return cls(weights=str(checkpoint_path), model=model) | ||||
|         return cls(weights=checkpoint_path, model=model) | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         self.weights: str = kwargs.get('weights', '') | ||||
|         self.weights: Path = Path(kwargs.get('weights', '')) | ||||
|         self.hparams: Path = self.weights.parent / 'hparams.yaml' | ||||
|  | ||||
|         self.model = kwargs.get('model', None) | ||||
|         assert self.model is not None | ||||
|  | ||||
|     def restore(self): | ||||
|         pretrained_model = self.model.load_from_checkpoint(self.weights) | ||||
|  | ||||
|         pretrained_model = self.model.load_from_checkpoint(self.weights.__str__()) | ||||
|         # , hparams_file=self.hparams.__str__()) | ||||
|         pretrained_model.eval() | ||||
|         pretrained_model.freeze() | ||||
|         return pretrained_model | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| import importlib | ||||
| import inspect | ||||
| import pickle | ||||
| import shelve | ||||
| from argparse import ArgumentParser | ||||
| from ast import literal_eval | ||||
| from pathlib import Path, PurePath | ||||
| from typing import Union | ||||
|  | ||||
| @@ -9,6 +12,13 @@ import torch | ||||
| import random | ||||
|  | ||||
|  | ||||
| def auto_cast(a): | ||||
|   try: | ||||
|     return literal_eval(a) | ||||
|   except: | ||||
|     return a | ||||
|  | ||||
|  | ||||
| def to_one_hot(idx_array, max_classes): | ||||
|     one_hot = np.zeros((idx_array.size, max_classes)) | ||||
|     one_hot[np.arange(idx_array.size), idx_array] = 1 | ||||
| @@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): | ||||
|            continue | ||||
|     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}') | ||||
|  | ||||
|  | ||||
| def add_argparse_args(cls, parent_parser): | ||||
|     parser = ArgumentParser(parents=[parent_parser], add_help=False) | ||||
|     full_arg_spec = inspect.getfullargspec(cls.__init__) | ||||
|     n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0) | ||||
|     for idx, argument in enumerate(full_arg_spec.args): | ||||
|         if argument == 'self': | ||||
|             continue | ||||
|         if idx < n_non_defaults: | ||||
|             parser.add_argument(f'--{argument}', type=int) | ||||
|         else: | ||||
|             argument_type = type(argument) | ||||
|             parser.add_argument(f'--{argument}', | ||||
|                                 type=argument_type, | ||||
|                                 default=full_arg_spec.defaults[idx - n_non_defaults] | ||||
|                                 ) | ||||
|     return parser | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium