bug in metric calculation
This commit is contained in:
parent
6816e423ff
commit
1d1b154460
@ -88,10 +88,7 @@ class LibrosaAudioToMel(object):
|
|||||||
def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
|
def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
|
||||||
assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
|
assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
|
||||||
# Mel kwargs are:
|
# Mel kwargs are:
|
||||||
# sr
|
# sr n_mels n_fft hop_length
|
||||||
# n_mels
|
|
||||||
# n_fft
|
|
||||||
# hop_length
|
|
||||||
|
|
||||||
self.mel_kwargs = mel_kwargs
|
self.mel_kwargs = mel_kwargs
|
||||||
self.amplitude_to_db = amplitude_to_db
|
self.amplitude_to_db = amplitude_to_db
|
||||||
|
@ -34,10 +34,10 @@ class LibrosaAudioToMelDataset(Dataset):
|
|||||||
self.audio_path = Path(audio_file_path)
|
self.audio_path = Path(audio_file_path)
|
||||||
|
|
||||||
mel_folder_suffix = self.audio_path.parent.parent.name
|
mel_folder_suffix = self.audio_path.parent.parent.name
|
||||||
|
self.mel_folder = Path(str(self.audio_path)
|
||||||
|
.replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent
|
||||||
|
|
||||||
self.mel_file_path = Path(str(self.audio_path)
|
self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy'
|
||||||
.replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')
|
|
||||||
.replace(self.audio_path.suffix, '.npy'))
|
|
||||||
|
|
||||||
self.audio_augmentations = audio_augmentations
|
self.audio_augmentations = audio_augmentations
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class LibrosaAudioToMelDataset(Dataset):
|
|||||||
self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
|
self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
|
||||||
mel_kwargs['n_mels'], transform=mel_augmentations)
|
mel_kwargs['n_mels'], transform=mel_augmentations)
|
||||||
|
|
||||||
self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs),
|
self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs),
|
||||||
MelToImage()
|
MelToImage()
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -64,9 +64,9 @@ class ShiftTime(_BaseTransformation):
|
|||||||
# Set to silence for heading/ tailing
|
# Set to silence for heading/ tailing
|
||||||
shift = int(shift)
|
shift = int(shift)
|
||||||
if shift > 0:
|
if shift > 0:
|
||||||
augmented_data[:shift, :] = 0
|
augmented_data[:, :shift] = 0
|
||||||
else:
|
else:
|
||||||
augmented_data[shift:, :] = 0
|
augmented_data[:, shift:] = 0
|
||||||
return augmented_data
|
return augmented_data
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
@ -5,6 +5,7 @@ from sklearn.ensemble import IsolationForest
|
|||||||
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
||||||
|
|
||||||
from ml_lib.metrics._base_score import _BaseScores
|
from ml_lib.metrics._base_score import _BaseScores
|
||||||
|
from ml_lib.utils.tools import to_one_hot
|
||||||
|
|
||||||
|
|
||||||
class BinaryScores(_BaseScores):
|
class BinaryScores(_BaseScores):
|
||||||
@ -17,16 +18,27 @@ class BinaryScores(_BaseScores):
|
|||||||
|
|
||||||
# Additional Score like the unweighted Average Recall:
|
# Additional Score like the unweighted Average Recall:
|
||||||
#########################
|
#########################
|
||||||
|
# INIT
|
||||||
|
if isinstance(outputs['batch_y'], torch.Tensor):
|
||||||
|
y_true = outputs['batch_y'].cpu().numpy()
|
||||||
|
else:
|
||||||
|
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||||
|
|
||||||
|
if isinstance(outputs['y'], torch.Tensor):
|
||||||
|
y_pred = outputs['y'].cpu().numpy()
|
||||||
|
else:
|
||||||
|
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||||
|
|
||||||
# UnweightedAverageRecall
|
# UnweightedAverageRecall
|
||||||
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
|
# y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||||
y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
|
# y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
|
||||||
|
|
||||||
# How to apply a threshold manualy
|
# How to apply a threshold manualy
|
||||||
# y_pred = (y_pred >= 0.5).astype(np.float32)
|
# y_pred = (y_pred >= 0.5).astype(np.float32)
|
||||||
|
|
||||||
# How to apply a threshold by IF (Isolation Forest)
|
# How to apply a threshold by IF (Isolation Forest)
|
||||||
clf = IsolationForest(random_state=self.model.seed)
|
clf = IsolationForest()
|
||||||
y_score = clf.fit_predict(y_pred.reshape(-1,1))
|
y_score = clf.fit_predict(y_pred.reshape(-1, 1))
|
||||||
y_score = (np.asarray(y_score) == -1).astype(np.float32)
|
y_score = (np.asarray(y_score) == -1).astype(np.float32)
|
||||||
|
|
||||||
uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
|
uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
|
||||||
|
@ -44,7 +44,21 @@ try:
|
|||||||
|
|
||||||
def update(self, preds, target) -> None:
|
def update(self, preds, target) -> None:
|
||||||
for _, metric in self:
|
for _, metric in self:
|
||||||
metric.update(preds, target)
|
try:
|
||||||
|
if self.n_classes <= 2:
|
||||||
|
metric.update(preds.unsqueeze(-1), target.unsqueeze(-1))
|
||||||
|
else:
|
||||||
|
metric.update(preds, target)
|
||||||
|
except ValueError:
|
||||||
|
print(f'error was: {ValueError}')
|
||||||
|
print(f'Metric is: {metric}')
|
||||||
|
print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}')
|
||||||
|
metric.update(preds.unsqueeze(-1), target)
|
||||||
|
except AssertionError:
|
||||||
|
print(f'error was: {AssertionError}')
|
||||||
|
print(f'Metric is: {metric}')
|
||||||
|
print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
|
||||||
|
metric.update(preds, target.unsqueeze(-1))
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
for _, metric in self:
|
for _, metric in self:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user