bringing branches up to date

This commit is contained in:
Steffen Illium 2021-02-15 11:39:54 +01:00
parent 010176e80b
commit a966321576
11 changed files with 216 additions and 197 deletions

View File

@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
class _AudioToMelDataset(Dataset, ABC):
import librosa
class LibrosaAudioToMelDataset(Dataset):
@property
def audio_file_duration(self):
raise NotImplementedError
return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
@property
def sampling_rate(self):
raise NotImplementedError
return self.mel_kwargs.get('sr', None)
def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
self.ignored_kwargs = kwargs
super(LibrosaAudioToMelDataset, self).__init__()
# audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
self.mel_kwargs = mel_kwargs
self.reset = reset
self.audio_path = Path(audio_file_path)
mel_folder_suffix = self.audio_path.parent.parent.name
self.mel_file_path = Path(str(self.audio_path)
.replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')
.replace(self.audio_path.suffix, '.npy'))
@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC):
self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
mel_kwargs['n_mels'], transform=mel_augmentations)
def _build_mel(self):
raise NotImplementedError
def __getitem__(self, item):
try:
return self.dataset[item]
except FileNotFoundError:
assert self._build_mel()
return self.dataset[item]
def __len__(self):
return len(self.dataset)
import librosa
class LibrosaAudioToMelDataset(_AudioToMelDataset):
@property
def audio_file_duration(self):
return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
@property
def sampling_rate(self):
return self.mel_kwargs.get('sr', None)
def __init__(self, audio_file_path, *args, **kwargs):
audio_file_path = Path(audio_file_path)
# audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
mel_kwargs = kwargs.get('mel_kwargs', dict())
mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
kwargs.update(mel_kwargs=mel_kwargs)
super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs),
MelToImage()
])
def _build_mel(self):
def __getitem__(self, item):
return self.dataset[item]
def __len__(self):
return len(self.dataset)
def build_mel(self):
    """Compute the mel representation of the audio file and cache it on disk.

    If ``self.reset`` is set, an existing cache file is removed first. When the
    cache file is present afterwards nothing else happens. Otherwise the audio
    at ``self.audio_path`` is loaded with librosa at ``self.sampling_rate``,
    passed through ``self._mel_transform`` and pickled to ``self.mel_file_path``.

    While the file is being built, a ``.lock`` file sits next to it. The lock
    is removed again even if the build fails, so a failed build does not block
    later attempts.

    Raises:
        FileExistsError: if the lock file already exists when the build starts.
    """
    if self.reset:
        self.mel_file_path.unlink(missing_ok=True)
    if self.mel_file_path.exists():
        return

    lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
    self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
    # Created outside the try block on purpose: if touch() fails, the lock was
    # not created by this call and must not be deleted here.
    lockfile.touch(exist_ok=False)
    try:
        raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate)
        mel_sample = self._mel_transform(raw_sample)
        with self.mel_file_path.open('wb') as mel_file:
            pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
    finally:
        lockfile.unlink(missing_ok=False)

View File

@ -11,13 +11,16 @@ class TorchMelDataset(Dataset):
def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
super(TorchMelDataset, self).__init__()
self.sampling_rate = sampling_rate
self.audio_file_len = audio_file_len
self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None
self.sampling_rate = int(sampling_rate)
self.audio_file_len = int(audio_file_len)
if auto_pad_to_shape and sub_segment_len:
self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
else:
self.padding = None
self.path = Path(mel_path)
self.sub_segment_len = sub_segment_len
self.mel_hop_len = mel_hop_len
self.sub_segment_hop_len = sub_segment_hop_len
self.sub_segment_len = int(sub_segment_len)
self.mel_hop_len = int(mel_hop_len)
self.sub_segment_hop_len = int(sub_segment_hop_len)
self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
if self.sub_segment_len and self.sub_segment_hop_len:
self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
@ -27,8 +30,6 @@ class TorchMelDataset(Dataset):
self.transform = transform
def __getitem__(self, item):
while Path(str(self.path).replace(self.path.suffix, '.lock')).exists():
time.sleep(0.01)
with self.path.open('rb') as mel_file:
mel_spec = pickle.load(mel_file, fix_imports=True)
start = self.offsets[item]
@ -38,7 +39,7 @@ class TorchMelDataset(Dataset):
snippet = self.transform(snippet)
if self.padding:
snippet = self.padding(snippet)
return snippet, self.label
return self.path.__str__(), snippet, self.label
def __len__(self):
return len(self.offsets)

View File

@ -1,67 +0,0 @@
import torchaudio
import sys

# Pick the torchaudio I/O backend by platform. ``sys.platform`` reports 'win32'
# on Windows (it is never 'windows'), so match on the 'win' prefix; otherwise
# the soundfile branch could never be taken.
if sys.platform.startswith('win'):
    torchaudio.set_audio_backend('soundfile')
else:
    torchaudio.set_audio_backend('sox_io')
class PyTorchAudioToMelDataset(_AudioToMelDataset):
    """Audio-to-mel dataset that builds its mel file with torchaudio.

    The mel representation is produced by ``MelSpectrogram(**self.mel_kwargs)``
    followed by ``MelToImage()`` and pickled to ``self.mel_file_path``.
    """

    @property
    def audio_file_duration(self):
        """``num_frames / sample_rate`` as reported by ``torchaudio.info`` for the audio file."""
        info_obj = torchaudio.info(self.audio_path)
        return info_obj.num_frames / info_obj.sample_rate

    @property
    def sampling_rate(self):
        """Target sampling rate, read from ``self.mel_kwargs['sample_rate']``."""
        return self.mel_kwargs['sample_rate']

    def __init__(self, audio_file_path, *args, **kwargs):
        """Forward all arguments to the base class, then set up the mel transform."""
        super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)

        from torchaudio.transforms import MelSpectrogram
        self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs),
                                       MelToImage()
                                       ])

    def _load_audio(self):
        """Return ``(waveform, sample_rate)`` for ``self.audio_path``.

        If torchaudio raises ``RuntimeError``, the file is read with soundfile,
        written back as PCM_32 and loaded again.
        """
        try:
            return torchaudio.load(self.audio_path)
        except RuntimeError:
            import soundfile

            data, samplerate = soundfile.read(self.audio_path)
            # NOTE: this overwrites the source audio file in place.
            soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32')
            return torchaudio.load(self.audio_path)

    def _build_mel(self):
        """Build the cached mel file if it is missing (or ``self.reset`` is set).

        Returns:
            bool: whether ``self.mel_file_path`` exists afterwards.

        Raises:
            FileExistsError: if the ``.lock`` file for this mel file already exists.
        """
        if self.reset:
            self.mel_file_path.unlink(missing_ok=True)
        if not self.mel_file_path.exists():
            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
            lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
            # Created outside the try block: a lock that this call did not
            # create must not be removed by it.
            lock_file.touch(exist_ok=False)
            try:
                audio_sample, sample_rate = self._load_audio()
                if sample_rate != self.sampling_rate:
                    resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate),
                                                              new_freq=int(self.sampling_rate))
                    audio_sample = resample(audio_sample)
                if audio_sample.shape[0] > 1:
                    # Transform Stereo to Mono
                    audio_sample = audio_sample.mean(dim=0, keepdim=True)
                mel_sample = self._mel_transform(audio_sample)
                with self.mel_file_path.open('wb') as mel_file:
                    pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
            finally:
                # Always release the lock; a stale lock would make every later
                # build fail in touch(exist_ok=False).
                lock_file.unlink()
        return self.mel_file_path.exists()

View File

@ -0,0 +1,47 @@
import numpy as np
from einops import reduce
import torch
from sklearn.ensemble import IsolationForest
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
from ml_lib.metrics._base_score import _BaseScores
# Score object that chains per-layer attention matrices into "joint" attentions
# (layer i is multiplied with the joint attention of layer i - 1).
# NOTE(review): this looks unfinished - see the notes inside __call__.
class AttentionRollout(_BaseScores):
def __init__(self, *args):
super(AttentionRollout, self).__init__(*args)
pass
# `outputs` is iterated and each element is indexed with 'attn_weights'; the value
# must support .cpu().numpy().
def __call__(self, outputs):
# NOTE(review): summary_dict is created but neither filled nor returned in the lines visible here.
summary_dict = dict()
#######################################################################################
# Additional Score - Histogram Distances - Image Plotting
#######################################################################################
#
# INIT
attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
# NOTE(review): the einops pattern is empty, and einops.reduce presumably also needs a
# reduction argument - TODO confirm; attn_reduce_heads is not used below.
attn_reduce_heads = [reduce(x, '') for x in attn_weights]
# NOTE(review): att_mat is never assigned in this method, so both branches below
# would raise NameError as written.
if self.model.params.use_residual:
# Add an identity matrix (broadcast over the first axis) and re-normalise over the last axis.
residual_att = np.eye(att_mat.shape[1])[None, ...]
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
else:
aug_att_mat = att_mat
joint_attentions = np.zeros(aug_att_mat.shape)
# The first axis is treated as the layer axis.
layers = joint_attentions.shape[0]
joint_attentions[0] = aug_att_mat[0]
for i in np.arange(1, layers):
joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])

View File

@ -113,16 +113,16 @@ class MultiClassScores(_BaseScores):
#######################################################################################
#
# Confusion matrix
fig1, ax1 = plt.subplots(dpi=96)
cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
labels=[class_names[key] for key in class_names.keys()],
normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=[class_names[i] for i in range(self.model.n_classes)]
)
disp.plot(include_values=True)
disp.plot(include_values=True, ax=ax1)
self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
# self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
plt.close('all')

View File

@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
for attn, mlp in zip(self.attns, self.mlps):
# Attention
skip_connection = tensor.clone()
tensor = self.norm(tensor)
attn_tensor = self.norm(tensor)
if return_attn_weights:
tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
attn_weights.append(attn_weight)
else:
tensor = attn(tensor, mask=mask)
tensor = tensor + skip_connection
attn_tensor = attn(attn_tensor, mask=mask)
tensor = attn_tensor + tensor
# MLP
skip_connection = tensor.clone()
tensor = self.norm(tensor)
tensor = mlp(tensor)
tensor = tensor + skip_connection
mlp_tensor = self.norm(tensor)
mlp_tensor = mlp(mlp_tensor)
tensor = tensor + mlp_tensor
return (tensor, attn_weights) if return_attn_weights else tensor

View File

@ -1,3 +1,6 @@
import inspect
from argparse import ArgumentParser
from functools import reduce
from abc import ABC
@ -5,13 +8,14 @@ from pathlib import Path
import torch
from operator import mul
from pytorch_lightning.utilities import argparse_utils
from torch import nn
from torch.nn import functional as F, Unfold
# Utility - Modules
###################
from ..utils.model_io import ModelParameters
from ..utils.tools import locate_and_import_class
from ..utils.tools import locate_and_import_class, add_argparse_args
try:
import pytorch_lightning as pl
@ -32,14 +36,18 @@ try:
print(e)
return -1
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
@classmethod
def from_argparse_args(cls, args, **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
# Set Parameters
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
self.lr = self.params.lr or 1e-4
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
def __init__(self, model_parameters, weight_init='xavier_normal_'):
super(LightningBaseModule, self).__init__()
self._weight_init = weight_init
self.params = ModelParameters(model_parameters)
def size(self):
return self.shape
@ -47,15 +55,6 @@ try:
def additional_scores(self, outputs):
raise NotImplementedError
@property
def dataset_class(self):
try:
return locate_and_import_class(self.params.class_name, folder_path='datasets')
except AttributeError as e:
raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' +
f'was not found!\n' +
f'{e}')
def save_to_disk(self, model_path):
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
@ -86,8 +85,12 @@ try:
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
def init_weights(self):
    """Apply the configured weight initialiser to this module via ``self.apply``.

    ``self._weight_init`` may be a callable or the name of a function in
    ``torch.nn.init``. A name is resolved to the function and stored back on
    the instance before use.
    """
    initializer = self._weight_init
    if isinstance(initializer, str):
        init_module = __import__('torch.nn.init', fromlist=[initializer])
        initializer = getattr(init_module, initializer)
        self._weight_init = initializer
    assert callable(self._weight_init)
    self.apply(WeightInit(in_place_init_function=self._weight_init))
module_types = (LightningBaseModule, nn.Module,)

29
utils/_basedatamodule.py Normal file
View File

@ -0,0 +1,29 @@
from pytorch_lightning import LightningDataModule
# Dataset Options
from ml_lib.utils.tools import add_argparse_args
DATA_OPTION_test = 'test'
DATA_OPTION_devel = 'devel'
DATA_OPTION_train = 'train'
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
class _BaseDataModule(LightningDataModule):
    """Base data module that keeps its datasets in a dict keyed by data option."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``LightningDataModule`` and start with no datasets."""
        super().__init__(*args, **kwargs)
        # Starts empty; `shape` expects an entry under DATA_OPTION_train.
        self.datasets = dict()

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Delegate to the module-level ``add_argparse_args`` helper for this class."""
        return add_argparse_args(cls, parent_parser)

    @property
    def shape(self):
        """``sample_shape`` of the training dataset."""
        train_set = self.datasets[DATA_OPTION_train]
        return train_set.sample_shape

    def transfer_batch_to_device(self, batch, device):
        """Return ``batch.to(device)``."""
        moved = batch.to(device)
        return moved

View File

@ -1,19 +1,34 @@
from abc import ABC
import inspect
from argparse import ArgumentParser
from pathlib import Path
import os
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from neptune.api_exceptions import ProjectNotFound
# noinspection PyUnresolvedReferences
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils
from .config import Config
from ml_lib.utils.tools import add_argparse_args
class Logger(LightningLoggerBase, ABC):
class Logger(LightningLoggerBase):
@classmethod
def from_argparse_args(cls, args, **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@property
def name(self) -> str:
return self._name
media_dir = 'media'
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
@property
def experiment(self):
if self.debug:
@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC):
def log_dir(self):
return Path(self.csvlogger.experiment.log_dir)
@property
def name(self):
return self.config.name
@property
def project_name(self):
return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
return f"{self.owner}/{self.name.replace('_', '-')}"
@property
def version(self):
return self.config.get('main', 'seed')
return self.seed
@property
def save_dir(self):
return self.log_dir
@property
def outpath(self):
return Path(self.config.train.outpath) / self.config.model.type
return Path(self.root_out) / self.model_name
@property
def exp_path(self):
return Path(self.outpath) / self.name
def __init__(self, config: Config):
def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
"""
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and each key-value pair can be
@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC):
"""
super(Logger, self).__init__()
self.config = config
self.debug = self.config.main.debug
if self.debug:
self.config.add_section('project')
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self.debug = debug
self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
self.owner = owner if not self.debug else 'testuser'
self.neptune_key = neptune_key if not self.debug else 'XXX'
self.root_out = outpath if not self.debug else 'debug_out'
self.seed = seed
self.model_name = model_name
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key,
api_key=self.neptune_key,
experiment_name=self.name,
project_name=self.project_name,
params=self.config.model_paramters)
project_name=self.project_name)
try:
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
except ProjectNotFound as e:
@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC):
print(e)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
self.log_config_as_ini()
def log_hyperparams(self, params):
self.neptunelogger.log_hyperparams(params)
@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC):
self.csvlogger.close()
self.neptunelogger.close()
def log_config_as_ini(self):
self.config.write(self.log_dir / 'config.ini')
def log_text(self, name, text, step_nb=0, **_):
def log_text(self, name, text, **_):
# TODO Implement Offline variant.
self.neptunelogger.log_text(name, text, step_nb)
self.neptunelogger.log_text(name, text)
def log_metric(self, metric_name, metric_value, **kwargs):
    """Log one metric value to both the CSV logger and the Neptune logger.

    Args:
        metric_name: name under which the value is recorded.
        metric_value: value to record.
        **kwargs: forwarded unchanged to ``neptunelogger.log_metric``.
    """
    # Key the mapping by the *value* of metric_name; dict(metric_name=...) would
    # store every metric under the literal key 'metric_name'.
    self.csvlogger.log_metrics({metric_name: metric_value})
    self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, ext='png', **kwargs):
step = kwargs.get('step', None)
image_name = f'{step}_{name}' if step is not None else name
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'

View File

@ -13,6 +13,10 @@ from torch import nn
# Hyperparameter Object
class ModelParameters(Namespace, Mapping):
@property
def as_dict(self):
    """Plain ``dict`` of all parameters; 'activation' is taken from ``activation_as_string``."""
    plain = dict()
    for key in self.keys():
        if key == 'activation':
            plain[key] = self.activation_as_string
        else:
            plain[key] = self.get(key)
    return plain
@property
def activation_as_string(self):
return self['activation'].lower()
@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
if name == 'activation':
return self._activations[self['activation'].lower()]
else:
try:
return super(ModelParameters, self).__getattribute__(name)
except AttributeError as e:
if name == 'stretch':
return False
else:
return None
return super(ModelParameters, self).__getattribute__(name)
_activations = dict(
leaky_relu=nn.LeakyReLU,
@ -88,16 +86,20 @@ class SavedLightningModels(object):
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
return cls(weights=str(checkpoint_path), model=model)
return cls(weights=checkpoint_path, model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.weights: Path = Path(kwargs.get('weights', ''))
self.hparams: Path = self.weights.parent / 'hparams.yaml'
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
# , hparams_file=self.hparams.__str__())
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model

View File

@ -1,6 +1,9 @@
import importlib
import inspect
import pickle
import shelve
from argparse import ArgumentParser
from ast import literal_eval
from pathlib import Path, PurePath
from typing import Union
@ -9,6 +12,13 @@ import torch
import random
def auto_cast(a):
    """Interpret ``a`` as a Python literal where possible.

    Args:
        a: usually a string such as ``'1'``, ``'0.5'``, ``'[1, 2]'`` or ``'None'``.

    Returns:
        The value produced by ``ast.literal_eval(a)``, or ``a`` unchanged when
        it is not a valid literal.
    """
    try:
        return literal_eval(a)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Only the errors literal_eval raises for malformed input fall back to
        # the raw value; KeyboardInterrupt / SystemExit are no longer swallowed.
        return a
def to_one_hot(idx_array, max_classes):
one_hot = np.zeros((idx_array.size, max_classes))
one_hot[np.arange(idx_array.size), idx_array] = 1
@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
continue
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
def add_argparse_args(cls, parent_parser):
    """Build a parser with one ``--<name>`` option per ``cls.__init__`` parameter.

    Args:
        cls: class whose ``__init__`` signature is inspected.
        parent_parser: parser whose arguments are inherited (passed via ``parents``).

    Returns:
        ArgumentParser: a new parser (``add_help=False``) with the extra options.

    Parameters with a default become options with that default, parsed to the
    type of the default value. Parameters without a default are parsed as ``int``.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    full_arg_spec = inspect.getfullargspec(cls.__init__)
    defaults = full_arg_spec.defaults or ()
    n_non_defaults = len(full_arg_spec.args) - len(defaults)
    for idx, argument in enumerate(full_arg_spec.args):
        if argument == 'self':
            continue
        if idx < n_non_defaults:
            # NOTE(review): required parameters are always parsed as int, so a
            # string-valued one cannot be passed on the command line - confirm
            # this is intended.
            parser.add_argument(f'--{argument}', type=int)
            continue
        default = defaults[idx - n_non_defaults]
        # The type comes from the default *value*; type(argument) would be the
        # type of the parameter name, i.e. always str.
        parser.add_argument(f'--{argument}',
                            type=_argparse_type_for(default),
                            default=default
                            )
    return parser


def _argparse_type_for(default):
    """Return the argparse ``type`` callable matching a default value."""
    if isinstance(default, bool):
        # bool('False') is True, so booleans need real string parsing.
        return _parse_bool_flag
    if isinstance(default, (int, float, str)):
        return type(default)
    # None and container defaults keep the raw string.
    return str


def _parse_bool_flag(text):
    """Parse a command-line string into a bool; raise ValueError when unrecognised."""
    lowered = text.strip().lower()
    if lowered in ('true', '1', 'yes', 'y', 'on'):
        return True
    if lowered in ('false', '0', 'no', 'n', 'off'):
        return False
    raise ValueError(f'Not a boolean value: {text!r}')