bug in metric calculation
This commit is contained in:
		| @@ -88,10 +88,7 @@ class LibrosaAudioToMel(object): | ||||
|     def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs): | ||||
|         assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!" | ||||
|         # Mel kwargs are: | ||||
|         #   sr | ||||
|         #   n_mels | ||||
|         #   n_fft | ||||
|         #   hop_length | ||||
|         #   sr   n_mels   n_fft   hop_length | ||||
|  | ||||
|         self.mel_kwargs = mel_kwargs | ||||
|         self.amplitude_to_db = amplitude_to_db | ||||
|   | ||||
| @@ -34,10 +34,10 @@ class LibrosaAudioToMelDataset(Dataset): | ||||
|         self.audio_path = Path(audio_file_path) | ||||
|  | ||||
|         mel_folder_suffix = self.audio_path.parent.parent.name | ||||
|         self.mel_folder = Path(str(self.audio_path) | ||||
|                                .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent | ||||
|  | ||||
|         self.mel_file_path = Path(str(self.audio_path) | ||||
|                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder') | ||||
|                                   .replace(self.audio_path.suffix, '.npy')) | ||||
|         self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy' | ||||
|  | ||||
|         self.audio_augmentations = audio_augmentations | ||||
|  | ||||
| @@ -45,7 +45,7 @@ class LibrosaAudioToMelDataset(Dataset): | ||||
|                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'], | ||||
|                                        mel_kwargs['n_mels'], transform=mel_augmentations) | ||||
|  | ||||
|         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs), | ||||
|         self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs), | ||||
|                                        MelToImage() | ||||
|                                        ]) | ||||
|  | ||||
|   | ||||
| @@ -64,9 +64,9 @@ class ShiftTime(_BaseTransformation): | ||||
|             # Set to silence for heading/ tailing | ||||
|             shift = int(shift) | ||||
|             if shift > 0: | ||||
|                 augmented_data[:shift, :] = 0 | ||||
|                 augmented_data[:, :shift] = 0 | ||||
|             else: | ||||
|                 augmented_data[shift:, :] = 0 | ||||
|                 augmented_data[:, shift:] = 0 | ||||
|             return augmented_data | ||||
|         else: | ||||
|             return x | ||||
|   | ||||
| @@ -5,6 +5,7 @@ from sklearn.ensemble import IsolationForest | ||||
| from sklearn.metrics import recall_score, roc_auc_score, average_precision_score | ||||
|  | ||||
| from ml_lib.metrics._base_score import _BaseScores | ||||
| from ml_lib.utils.tools import to_one_hot | ||||
|  | ||||
|  | ||||
| class BinaryScores(_BaseScores): | ||||
| @@ -17,16 +18,27 @@ class BinaryScores(_BaseScores): | ||||
|  | ||||
|         # Additional Score like the unweighted Average Recall: | ||||
|         ######################### | ||||
|         # INIT | ||||
|         if isinstance(outputs['batch_y'], torch.Tensor): | ||||
|             y_true = outputs['batch_y'].cpu().numpy() | ||||
|         else: | ||||
|             y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() | ||||
|  | ||||
|         if isinstance(outputs['y'], torch.Tensor): | ||||
|             y_pred = outputs['y'].cpu().numpy() | ||||
|         else: | ||||
|             y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy() | ||||
|  | ||||
|         # UnweightedAverageRecall | ||||
|         y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() | ||||
|         y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy() | ||||
|         # y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() | ||||
|         # y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy() | ||||
|  | ||||
|         # How to apply a threshold manualy | ||||
|         # y_pred = (y_pred >= 0.5).astype(np.float32) | ||||
|  | ||||
|         # How to apply a threshold by IF (Isolation Forest) | ||||
|         clf = IsolationForest(random_state=self.model.seed) | ||||
|         y_score = clf.fit_predict(y_pred.reshape(-1,1)) | ||||
|         clf = IsolationForest() | ||||
|         y_score = clf.fit_predict(y_pred.reshape(-1, 1)) | ||||
|         y_score = (np.asarray(y_score) == -1).astype(np.float32) | ||||
|  | ||||
|         uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro', | ||||
|   | ||||
| @@ -44,7 +44,21 @@ try: | ||||
|  | ||||
|         def update(self, preds, target) -> None: | ||||
|             for _, metric in self: | ||||
|                 metric.update(preds, target) | ||||
|                 try: | ||||
|                     if self.n_classes <= 2: | ||||
|                         metric.update(preds.unsqueeze(-1), target.unsqueeze(-1)) | ||||
|                     else: | ||||
|                         metric.update(preds, target) | ||||
|                 except ValueError: | ||||
|                     print(f'error was: {ValueError}') | ||||
|                     print(f'Metric is: {metric}') | ||||
|                     print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}') | ||||
|                     metric.update(preds.unsqueeze(-1), target) | ||||
|                 except AssertionError: | ||||
|                     print(f'error was: {AssertionError}') | ||||
|                     print(f'Metric is: {metric}') | ||||
|                     print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}') | ||||
|                     metric.update(preds, target.unsqueeze(-1)) | ||||
|  | ||||
|         def reset(self) -> None: | ||||
|             for _, metric in self: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium