From a9663215760b5899217838544e2cdb127b28587e Mon Sep 17 00:00:00 2001
From: Steffen Illium
Date: Mon, 15 Feb 2021 11:39:54 +0100
Subject: [PATCH] bringing branches up to date

---
 audio_toolset/audio_to_mel_dataset.py | 63 +++++++----------------
 audio_toolset/mel_dataset.py          | 19 +++----
 experiments.py                        | 67 -------------------------
 metrics/attention_rollout.py          | 47 +++++++++++++++++
 metrics/multi_class_classification.py |  8 +--
 modules/blocks.py                     | 16 +++---
 modules/util.py                       | 41 ++++++++-------
 utils/_basedatamodule.py              | 29 +++++++++++
 utils/logging.py                      | 72 +++++++++++++++------------
 utils/model_io.py                     | 24 +++++----
 utils/tools.py                        | 27 ++++++++++
 11 files changed, 216 insertions(+), 197 deletions(-)
 delete mode 100644 experiments.py
 create mode 100644 metrics/attention_rollout.py
 create mode 100644 utils/_basedatamodule.py

diff --git a/audio_toolset/audio_to_mel_dataset.py b/audio_toolset/audio_to_mel_dataset.py
index 058a326..f86e1aa 100644
--- a/audio_toolset/audio_to_mel_dataset.py
+++ b/audio_toolset/audio_to_mel_dataset.py
@@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
 from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
 
-class _AudioToMelDataset(Dataset, ABC):
+import librosa
+
+
+class LibrosaAudioToMelDataset(Dataset):
 
     @property
     def audio_file_duration(self):
-        raise NotImplementedError
+        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
 
     @property
     def sampling_rate(self):
-        raise NotImplementedError
+        return self.mel_kwargs.get('sr', None)
 
     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
-        self.ignored_kwargs = kwargs
+        super(LibrosaAudioToMelDataset, self).__init__()
+
+        mel_kwargs = mel_kwargs if mel_kwargs is not None else dict()  # fall back to an empty dict, None has no .update()
+        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
         self.mel_kwargs = mel_kwargs
         self.reset = reset
         self.audio_path = Path(audio_file_path)
 
         mel_folder_suffix = self.audio_path.parent.parent.name
+
         self.mel_file_path = Path(str(self.audio_path)
                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')
                                   .replace(self.audio_path.suffix, '.npy'))
@@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC):
                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
                                        mel_kwargs['n_mels'], transform=mel_augmentations)
 
-    def _build_mel(self):
-        raise NotImplementedError
-
-    def __getitem__(self, item):
-        try:
-            return self.dataset[item]
-        except FileNotFoundError:
-            assert self._build_mel()
-            return self.dataset[item]
-
-    def __len__(self):
-        return len(self.dataset)
-
-
-import librosa
-
-
-class LibrosaAudioToMelDataset(_AudioToMelDataset):
-
-    @property
-    def audio_file_duration(self):
-        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
-
-    @property
-    def sampling_rate(self):
-        return self.mel_kwargs.get('sr', None)
-
-    def __init__(self, audio_file_path, *args, **kwargs):
-
-        audio_file_path = Path(audio_file_path)
-        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
-        mel_kwargs = kwargs.get('mel_kwargs', dict())
-        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
-        kwargs.update(mel_kwargs=mel_kwargs)
-
-        super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
-
         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs),
                                        MelToImage()
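+                                       # this chain runs once per audio file in `build_mel` below;
+                                       # the resulting mel array is pickled to `self.mel_file_path`
+                                       # and then served snippet-wise by `TorchMelDataset`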
]) - def _build_mel(self): + def __getitem__(self, item): + return self.dataset[item] + + def __len__(self): + return len(self.dataset) + + def build_mel(self): if self.reset: self.mel_file_path.unlink(missing_ok=True) if not self.mel_file_path.exists(): - lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) - lockfile.touch(exist_ok=False) raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate) mel_sample = self._mel_transform(raw_sample) with self.mel_file_path.open('wb') as mel_file: pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) - lockfile.unlink(missing_ok=False) else: pass diff --git a/audio_toolset/mel_dataset.py b/audio_toolset/mel_dataset.py index 6b6f245..4948736 100644 --- a/audio_toolset/mel_dataset.py +++ b/audio_toolset/mel_dataset.py @@ -11,13 +11,16 @@ class TorchMelDataset(Dataset): def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len, sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True): super(TorchMelDataset, self).__init__() - self.sampling_rate = sampling_rate - self.audio_file_len = audio_file_len - self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None + self.sampling_rate = int(sampling_rate) + self.audio_file_len = int(audio_file_len) + if auto_pad_to_shape and sub_segment_len: + self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len))) + else: + self.padding = None self.path = Path(mel_path) - self.sub_segment_len = sub_segment_len - self.mel_hop_len = mel_hop_len - self.sub_segment_hop_len = sub_segment_hop_len + self.sub_segment_len = int(sub_segment_len) + self.mel_hop_len = int(mel_hop_len) + self.sub_segment_hop_len = int(sub_segment_hop_len) self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1) if self.sub_segment_len and self.sub_segment_hop_len: self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) @@ -27,8 +30,6 @@ class TorchMelDataset(Dataset): self.transform = transform def __getitem__(self, item): - while Path(str(self.path).replace(self.path.suffix, '.lock')).exists(): - time.sleep(0.01) with self.path.open('rb') as mel_file: mel_spec = pickle.load(mel_file, fix_imports=True) start = self.offsets[item] @@ -38,7 +39,7 @@ class TorchMelDataset(Dataset): snippet = self.transform(snippet) if self.padding: snippet = self.padding(snippet) - return snippet, self.label + return self.path.__str__(), snippet, self.label def __len__(self): return len(self.offsets) diff --git a/experiments.py b/experiments.py deleted file mode 100644 index 66d4171..0000000 --- a/experiments.py +++ /dev/null @@ -1,67 +0,0 @@ - -import torchaudio -if sys.platform =='windows': - torchaudio.set_audio_backend('soundfile') -else: - torchaudio.set_audio_backend('sox_io') - - -class PyTorchAudioToMelDataset(_AudioToMelDataset): - - @property - def audio_file_duration(self): - info_obj = torchaudio.info(self.audio_path) - return info_obj.num_frames / info_obj.sample_rate - - @property - def sampling_rate(self): - return self.mel_kwargs['sample_rate'] - - def __init__(self, audio_file_path, *args, **kwargs): - super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) - - audio_file_path = Path(audio_file_path) - # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) - - from torchaudio.transforms import MelSpectrogram - 
        self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs),
-                                       MelToImage()
-                                       ])
-
-    def _build_mel(self):
-        if self.reset:
-            self.mel_file_path.unlink(missing_ok=True)
-        if not self.mel_file_path.exists():
-            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
-            lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
-            lock_file.touch(exist_ok=False)
-
-            try:
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            except RuntimeError:
-                import soundfile
-
-                data, samplerate = soundfile.read(self.audio_path)
-                # sf.available_formats()
-                # sf.available_subtypes()
-                soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32')
-
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            if sample_rate != self.sampling_rate:
-                resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate))
-                audio_sample = resample(audio_sample)
-            if audio_sample.shape[0] > 1:
-                # Transform Stereo to Mono
-                audio_sample = audio_sample.mean(dim=0, keepdim=True)
-            mel_sample = self._mel_transform(audio_sample)
-            with self.mel_file_path.open('wb') as mel_file:
-                pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
-            lock_file.unlink()
-        else:
-            # print(f"Already existed.. Skipping {filename}")
-            # mel_file = mel_file
-            pass
-
-        # with mel_file.open(mode='rb') as f:
-        #     mel_sample = pickle.load(f, fix_imports=True)
-        return self.mel_file_path.exists()
diff --git a/metrics/attention_rollout.py b/metrics/attention_rollout.py
new file mode 100644
index 0000000..972f151
--- /dev/null
+++ b/metrics/attention_rollout.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+from einops import reduce
+
+
+import torch
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
+
+from ml_lib.metrics._base_score import _BaseScores
+
+
+class AttentionRollout(_BaseScores):
+
+    def __init__(self, *args):
+        super(AttentionRollout, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+        #######################################################################################
+        # Additional Score - Attention Rollout
+        #######################################################################################
+        # INIT: each `output['attn_weights']` is assumed to have shape
+        # (layers, heads, tokens, tokens); batch and heads are averaged out.
+        attn_weights = np.stack([output['attn_weights'].cpu().numpy() for output in outputs])
+        att_mat = reduce(attn_weights, 'b l h i j -> l i j', 'mean')
+
+        if self.model.params.use_residual:
+            residual_att = np.eye(att_mat.shape[1])[None, ...]
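+            # Attention rollout (Abnar & Zuidema, 2020): the identity is mixed in to
+            # account for the skip connections, then rows are renormalised so every
+            # layer stays a stochastic matrix before the layerwise matrix product.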
+            aug_att_mat = att_mat + residual_att
+            aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
+        else:
+            aug_att_mat = att_mat
+
+        joint_attentions = np.zeros(aug_att_mat.shape)
+
+        layers = joint_attentions.shape[0]
+        joint_attentions[0] = aug_att_mat[0]
+        for i in np.arange(1, layers):
+            joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
+
+        summary_dict.update(attention_rollout=joint_attentions)
+        return summary_dict
diff --git a/metrics/multi_class_classification.py b/metrics/multi_class_classification.py
index fe7d6d8..4bb77f6 100644
--- a/metrics/multi_class_classification.py
+++ b/metrics/multi_class_classification.py
@@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores):
         #######################################################################################
         #
         # Confusion matrix
-
+        fig1, ax1 = plt.subplots(dpi=96)
         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                               labels=[class_names[key] for key in class_names.keys()],
                               normalize='all')
         disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                       display_labels=[class_names[i] for i in range(self.model.n_classes)]
                                       )
-        disp.plot(include_values=True)
+        disp.plot(include_values=True, ax=ax1)
 
-        self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
+        self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
 
         plt.close('all')
-        return summary_dict
\ No newline at end of file
+        return summary_dict
diff --git a/modules/blocks.py b/modules/blocks.py
index 2d3a359..d4c425b 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
 
         for attn, mlp in zip(self.attns, self.mlps):
             # Attention
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
+            attn_tensor = self.norm(tensor)
             if return_attn_weights:
-                tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
+                attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
                 attn_weights.append(attn_weight)
             else:
-                tensor = attn(tensor, mask=mask)
-            tensor = tensor + skip_connection
+                attn_tensor = attn(attn_tensor, mask=mask)
+            tensor = attn_tensor + tensor
 
             # MLP
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
-            tensor = mlp(tensor)
-            tensor = tensor + skip_connection
+            mlp_tensor = self.norm(tensor)
+            mlp_tensor = mlp(mlp_tensor)
+            tensor = tensor + mlp_tensor
 
         return (tensor, attn_weights) if return_attn_weights else tensor
diff --git a/modules/util.py b/modules/util.py
index be56c35..6492bc2 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,3 +1,6 @@
+import inspect
+from argparse import ArgumentParser
+
 from functools import reduce
 from abc import ABC
 
@@ -5,13 +8,14 @@ from pathlib import Path
 
 import torch
 from operator import mul
+from pytorch_lightning.utilities import argparse_utils
 from torch import nn
 from torch.nn import functional as F, Unfold
 
 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters
-from ..utils.tools import locate_and_import_class
+from ..utils.tools import locate_and_import_class, add_argparse_args
 
 try:
     import pytorch_lightning as pl
@@ -32,14 +36,18 @@ try:
             print(e)
             return -1
 
-        def __init__(self, hparams):
-            super(LightningBaseModule, self).__init__()
+        @classmethod
+        def from_argparse_args(cls, args, **kwargs):
+            return argparse_utils.from_argparse_args(cls, args, **kwargs)
 
-            # Set Parameters
-
################################ - self.hparams = hparams - self.params = ModelParameters(hparams) - self.lr = self.params.lr or 1e-4 + @classmethod + def add_argparse_args(cls, parent_parser): + return add_argparse_args(cls, parent_parser) + + def __init__(self, model_parameters, weight_init='xavier_normal_'): + super(LightningBaseModule, self).__init__() + self._weight_init = weight_init + self.params = ModelParameters(model_parameters) def size(self): return self.shape @@ -47,15 +55,6 @@ try: def additional_scores(self, outputs): raise NotImplementedError - @property - def dataset_class(self): - try: - return locate_and_import_class(self.params.class_name, folder_path='datasets') - except AttributeError as e: - raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' + - f'was not found!\n' + - f'{e}') - def save_to_disk(self, model_path): Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True) if not (model_path / 'model_class.obj').exists(): @@ -86,8 +85,12 @@ try: def test_epoch_end(self, outputs): raise NotImplementedError - def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_): - weight_initializer = WeightInit(in_place_init_function=in_place_init_func_) + def init_weights(self): + if isinstance(self._weight_init, str): + mod = __import__('torch.nn.init', fromlist=[self._weight_init]) + self._weight_init = getattr(mod, self._weight_init) + assert callable(self._weight_init) + weight_initializer = WeightInit(in_place_init_function=self._weight_init) self.apply(weight_initializer) module_types = (LightningBaseModule, nn.Module,) diff --git a/utils/_basedatamodule.py b/utils/_basedatamodule.py new file mode 100644 index 0000000..c2e0a2f --- /dev/null +++ b/utils/_basedatamodule.py @@ -0,0 +1,29 @@ +from pytorch_lightning import LightningDataModule + + +# Dataset Options +from ml_lib.utils.tools import add_argparse_args + +DATA_OPTION_test = 'test' +DATA_OPTION_devel = 'devel' +DATA_OPTION_train = 'train' +DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test] + + +class _BaseDataModule(LightningDataModule): + + @property + def shape(self): + return self.datasets[DATA_OPTION_train].sample_shape + + @classmethod + def add_argparse_args(cls, parent_parser): + return add_argparse_args(cls, parent_parser) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.datasets = dict() + + def transfer_batch_to_device(self, batch, device): + return batch.to(device) + diff --git a/utils/logging.py b/utils/logging.py index d0d983f..f9ef373 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -1,19 +1,34 @@ -from abc import ABC +import inspect +from argparse import ArgumentParser from pathlib import Path +import os from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.neptune import NeptuneLogger from neptune.api_exceptions import ProjectNotFound -# noinspection PyUnresolvedReferences + from pytorch_lightning.loggers.csv_logs import CSVLogger +from pytorch_lightning.utilities import argparse_utils -from .config import Config +from ml_lib.utils.tools import add_argparse_args -class Logger(LightningLoggerBase, ABC): +class Logger(LightningLoggerBase): + + @classmethod + def from_argparse_args(cls, args, **kwargs): + return argparse_utils.from_argparse_args(cls, args, **kwargs) + + @property + def name(self) -> str: + return self._name media_dir = 'media' + @classmethod + def add_argparse_args(cls, parent_parser): + return add_argparse_args(cls, parent_parser) + 
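+    # NOTE: this logger fans out to two backends: a NeptuneLogger (remote tracking,
+    # forced into offline mode while `debug` is set) and a local CSVLogger rooted
+    # at `outpath`; the properties below expose their shared metadata.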
@property def experiment(self): if self.debug: @@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC): def log_dir(self): return Path(self.csvlogger.experiment.log_dir) - @property - def name(self): - return self.config.name - @property def project_name(self): - return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}" + return f"{self.owner}/{self.name.replace('_', '-')}" @property def version(self): - return self.config.get('main', 'seed') + return self.seed + + @property + def save_dir(self): + return self.log_dir @property def outpath(self): - return Path(self.config.train.outpath) / self.config.model.type + return Path(self.root_out) / self.model_name - @property - def exp_path(self): - return Path(self.outpath) / self.name - - def __init__(self, config: Config): + def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False): """ params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only. Parameters are displayed in the experiment’s Parameters section and each key-value pair can be @@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC): """ super(Logger, self).__init__() - self.config = config - self.debug = self.config.main.debug - if self.debug: - self.config.add_section('project') - self.config.set('project', 'owner', 'testuser') - self.config.set('project', 'name', 'test') - self.config.set('project', 'neptune_key', 'XXX') + self.debug = debug + self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test' + self.owner = owner if not self.debug else 'testuser' + self.neptune_key = neptune_key if not self.debug else 'XXX' + self.root_out = outpath if not self.debug else 'debug_out' + self.seed = seed + self.model_name = model_name + self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) self._neptune_kwargs = dict(offline_mode=self.debug, - api_key=self.config.project.neptune_key, + api_key=self.neptune_key, experiment_name=self.name, - project_name=self.project_name, - params=self.config.model_paramters) + project_name=self.project_name) try: self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) except ProjectNotFound as e: @@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC): print(e) self.csvlogger = CSVLogger(**self._csvlogger_kwargs) - self.log_config_as_ini() def log_hyperparams(self, params): self.neptunelogger.log_hyperparams(params) @@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC): self.csvlogger.close() self.neptunelogger.close() - def log_config_as_ini(self): - self.config.write(self.log_dir / 'config.ini') - - def log_text(self, name, text, step_nb=0, **_): + def log_text(self, name, text, **_): # TODO Implement Offline variant. 
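+        # for now only Neptune receives free-form text; the CSVLogger backend has
+        # no equivalent channel (see the TODO above about an offline variant)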
-        self.neptunelogger.log_text(name, text, step_nb)
+        self.neptunelogger.log_text(name, text)
 
     def log_metric(self, metric_name, metric_value, **kwargs):
         self.csvlogger.log_metrics(dict(metric_name=metric_value))
         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
 
     def log_image(self, name, image, ext='png', **kwargs):
-
         step = kwargs.get('step', None)
         image_name = f'{step}_{name}' if step is not None else name
         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
diff --git a/utils/model_io.py b/utils/model_io.py
index 9bcc18d..f3ae503 100644
--- a/utils/model_io.py
+++ b/utils/model_io.py
@@ -13,6 +13,10 @@ from torch import nn
 # Hyperparamter Object
 class ModelParameters(Namespace, Mapping):
 
+    @property
+    def as_dict(self):
+        return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
+
     @property
     def activation_as_string(self):
         return self['activation'].lower()
@@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
         if name == 'activation':
             return self._activations[self['activation'].lower()]
         else:
-            try:
-                return super(ModelParameters, self).__getattribute__(name)
-            except AttributeError as e:
-                if name == 'stretch':
-                    return False
-                else:
-                    return None
+            return super(ModelParameters, self).__getattribute__(name)
 
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
@@ -88,16 +86,20 @@ class SavedLightningModels(object):
         model = torch.load(models_root_path / 'model_class.obj')
         assert model is not None
 
-        return cls(weights=str(checkpoint_path), model=model)
+        return cls(weights=checkpoint_path, model=model)
 
     def __init__(self, **kwargs):
-        self.weights: str = kwargs.get('weights', '')
+        self.weights: Path = Path(kwargs.get('weights', ''))
+        self.hparams: Path = self.weights.parent / 'hparams.yaml'
 
         self.model = kwargs.get('model', None)
         assert self.model is not None
 
     def restore(self):
-        pretrained_model = self.model.load_from_checkpoint(self.weights)
+
+        pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
+        # , hparams_file=self.hparams.__str__())
         pretrained_model.eval()
         pretrained_model.freeze()
-        return pretrained_model
\ No newline at end of file
+        return pretrained_model
+
diff --git a/utils/tools.py b/utils/tools.py
index a85b495..08e2c20 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -1,6 +1,9 @@
 import importlib
+import inspect
 import pickle
 import shelve
+from argparse import ArgumentParser
+from ast import literal_eval
 from pathlib import Path, PurePath
 from typing import Union
 
@@ -9,6 +12,13 @@ import torch
 import random
 
 
+def auto_cast(a):
+    try:
+        return literal_eval(a)
+    except (ValueError, SyntaxError):  # not a Python literal: keep the raw string
+        return a
+
+
 def to_one_hot(idx_array, max_classes):
     one_hot = np.zeros((idx_array.size, max_classes))
     one_hot[np.arange(idx_array.size), idx_array] = 1
@@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
             continue
     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
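+
+# NOTE: modelled after pytorch_lightning's `add_argparse_args`: inspect the wrapped
+# class' `__init__` and register one CLI flag per constructor parameter. For
+# defaulted parameters, `auto_cast` (above) recovers int/float/bool/None values
+# from the raw command-line string.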
+
+def add_argparse_args(cls, parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    full_arg_spec = inspect.getfullargspec(cls.__init__)
+    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
+    for idx, argument in enumerate(full_arg_spec.args):
+        if argument == 'self':
+            continue
+        if idx < n_non_defaults:
+            parser.add_argument(f'--{argument}', type=int)
+        else:
+            argument_type = auto_cast
+            parser.add_argument(f'--{argument}',
+                                type=argument_type,
+                                default=full_arg_spec.defaults[idx - n_non_defaults]
+                                )
+    return parser
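
--
Usage sketch (illustrative, not part of the patch): how the new `add_argparse_args`
helper from utils/tools.py is meant to be wired into a class, as `LightningBaseModule`
and `_BaseDataModule` do above. `DummyModule` and the flag values below are assumed
names for demonstration only.

    from argparse import ArgumentParser

    from ml_lib.utils.tools import add_argparse_args


    class DummyModule:
        """Illustrative stand-in for a LightningBaseModule subclass."""

        def __init__(self, model_parameters, weight_init='xavier_normal_', lr=1e-4):
            self.params = model_parameters
            self._weight_init = weight_init
            self.lr = lr

        @classmethod
        def add_argparse_args(cls, parent_parser):
            return add_argparse_args(cls, parent_parser)


    if __name__ == '__main__':
        parser = DummyModule.add_argparse_args(ArgumentParser())
        args = parser.parse_args(['--weight_init', 'kaiming_normal_', '--lr', '0.001'])
        # auto_cast turns the '--lr' string into a float, while 'kaiming_normal_'
        # fails literal_eval and stays a plain string
        print(args.lr, type(args.lr))  # -> 0.001 <class 'float'>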