From 1d1b154460b305f1dd60d9f463c29f8d204c4360 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Sat, 27 Mar 2021 16:39:07 +0100 Subject: [PATCH] bug in metric calculation --- audio_toolset/audio_io.py | 5 +---- audio_toolset/audio_to_mel_dataset.py | 8 ++++---- audio_toolset/mel_augmentation.py | 4 ++-- metrics/binary_class_classifictaion.py | 20 ++++++++++++++++---- modules/util.py | 16 +++++++++++++++- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/audio_toolset/audio_io.py b/audio_toolset/audio_io.py index 530daf0..c17deb7 100644 --- a/audio_toolset/audio_io.py +++ b/audio_toolset/audio_io.py @@ -88,10 +88,7 @@ class LibrosaAudioToMel(object): def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs): assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!" # Mel kwargs are: - # sr - # n_mels - # n_fft - # hop_length + # sr n_mels n_fft hop_length self.mel_kwargs = mel_kwargs self.amplitude_to_db = amplitude_to_db diff --git a/audio_toolset/audio_to_mel_dataset.py b/audio_toolset/audio_to_mel_dataset.py index a98ce90..492ed34 100644 --- a/audio_toolset/audio_to_mel_dataset.py +++ b/audio_toolset/audio_to_mel_dataset.py @@ -34,10 +34,10 @@ class LibrosaAudioToMelDataset(Dataset): self.audio_path = Path(audio_file_path) mel_folder_suffix = self.audio_path.parent.parent.name + self.mel_folder = Path(str(self.audio_path) + .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent - self.mel_file_path = Path(str(self.audio_path) - .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder') - .replace(self.audio_path.suffix, '.npy')) + self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy' self.audio_augmentations = audio_augmentations @@ -45,7 +45,7 @@ class LibrosaAudioToMelDataset(Dataset): self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'], mel_kwargs['n_mels'], transform=mel_augmentations) - self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), + self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs), MelToImage() ]) diff --git a/audio_toolset/mel_augmentation.py b/audio_toolset/mel_augmentation.py index 9f7516c..f708ecf 100644 --- a/audio_toolset/mel_augmentation.py +++ b/audio_toolset/mel_augmentation.py @@ -64,9 +64,9 @@ class ShiftTime(_BaseTransformation): # Set to silence for heading/ tailing shift = int(shift) if shift > 0: - augmented_data[:shift, :] = 0 + augmented_data[:, :shift] = 0 else: - augmented_data[shift:, :] = 0 + augmented_data[:, shift:] = 0 return augmented_data else: return x diff --git a/metrics/binary_class_classifictaion.py b/metrics/binary_class_classifictaion.py index 568f522..2a6e3bf 100644 --- a/metrics/binary_class_classifictaion.py +++ b/metrics/binary_class_classifictaion.py @@ -5,6 +5,7 @@ from sklearn.ensemble import IsolationForest from sklearn.metrics import recall_score, roc_auc_score, average_precision_score from ml_lib.metrics._base_score import _BaseScores +from ml_lib.utils.tools import to_one_hot class BinaryScores(_BaseScores): @@ -17,16 +18,27 @@ class BinaryScores(_BaseScores): # Additional Score like the unweighted Average Recall: ######################### + # INIT + if isinstance(outputs['batch_y'], torch.Tensor): + y_true = outputs['batch_y'].cpu().numpy() + else: + y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() + + if isinstance(outputs['y'], torch.Tensor): + y_pred = outputs['y'].cpu().numpy() + else: + y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy() + # UnweightedAverageRecall - y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() - y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy() + # y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() + # y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy() # How to apply a threshold manualy # y_pred = (y_pred >= 0.5).astype(np.float32) # How to apply a threshold by IF (Isolation Forest) - clf = IsolationForest(random_state=self.model.seed) - y_score = clf.fit_predict(y_pred.reshape(-1,1)) + clf = IsolationForest() + y_score = clf.fit_predict(y_pred.reshape(-1, 1)) y_score = (np.asarray(y_score) == -1).astype(np.float32) uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro', diff --git a/modules/util.py b/modules/util.py index edc6874..0dd9847 100644 --- a/modules/util.py +++ b/modules/util.py @@ -44,7 +44,21 @@ try: def update(self, preds, target) -> None: for _, metric in self: - metric.update(preds, target) + try: + if self.n_classes <= 2: + metric.update(preds.unsqueeze(-1), target.unsqueeze(-1)) + else: + metric.update(preds, target) + except ValueError: + print(f'error was: {ValueError}') + print(f'Metric is: {metric}') + print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}') + metric.update(preds.unsqueeze(-1), target) + except AssertionError: + print(f'error was: {AssertionError}') + print(f'Metric is: {metric}') + print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}') + metric.update(preds, target.unsqueeze(-1)) def reset(self) -> None: for _, metric in self: