56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
import numpy as np
|
|
|
|
import torch
|
|
from sklearn.ensemble import IsolationForest
|
|
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
|
|
|
from ml_lib.metrics._base_score import _BaseScores
|
|
|
|
|
|
class BinaryScores(_BaseScores):
    """Binary anomaly-detection metrics computed from collected step outputs.

    Calling an instance with the outputs gathered over an evaluation run
    returns a dict with the keys ``'uar_score'``, ``'precision_score'``,
    ``'auc_score'`` and ``'pauc_score'``.
    """

    # Upper false-positive-rate bound used for the partial AUC ('pauc_score').
    PAUC_MAX_FPR = 0.15

    def __init__(self, *args):
        # All positional arguments are forwarded unchanged to _BaseScores.
        super().__init__(*args)

    @staticmethod
    def _safe_roc_auc(y_true, y_score, max_fpr=None):
        """Return ``roc_auc_score(y_true, y_score, max_fpr=max_fpr)``, or ``-1``.

        ``-1`` is returned when sklearn raises ``ValueError`` (for example when
        ``y_true`` contains a single class), so one degenerate evaluation set
        does not abort the whole metric summary.
        """
        try:
            return roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=max_fpr)
        except ValueError:
            return -1

    def __call__(self, outputs):
        """Compute UAR, average precision, ROC AUC and partial ROC AUC.

        Args:
            outputs: Iterable of mappings, each holding a ``'batch_y'`` tensor of
                ground-truth labels and an ``'element_wise_recon_error'`` tensor
                used as the per-sample anomaly score. Assumes label 1 marks
                anomalies and one error value per sample -- TODO confirm.

        Returns:
            dict: ``uar_score`` (macro recall of IsolationForest-thresholded
            predictions), ``precision_score`` (average precision of the raw
            error), and ``auc_score`` / ``pauc_score`` (ROC AUC of the raw
            error, full and with ``max_fpr=PAUC_MAX_FPR``; ``-1`` when sklearn
            cannot compute it).

        Raises:
            ``torch.cat`` raises if ``outputs`` is empty.
        """
        # Materialise once: the outputs are traversed twice below.
        outputs = list(outputs)

        # Gather all batches into flat NumPy vectors. ``detach`` avoids the
        # error NumPy conversion raises for tensors that still require grad;
        # ``reshape(-1)`` flattens (N, 1) columns and, unlike ``squeeze``,
        # never produces a 0-d array when there is a single sample.
        y_true = (
            torch.cat([output['batch_y'] for output in outputs])
            .detach().cpu().numpy().reshape(-1)
        )
        recon_error = (
            torch.cat([output['element_wise_recon_error'] for output in outputs])
            .detach().cpu().numpy().reshape(-1)
        )

        # Hard labels for the UAR: an IsolationForest fitted on the 1-D errors
        # returns -1 for outliers and +1 for inliers, mapped here to 1.0 / 0.0.
        # A fixed cut-off (e.g. ``recon_error >= 0.5``) is the manual alternative.
        # NOTE(review): in 1-D the forest can also isolate unusually *low*
        # errors, so those may be labelled anomalous as well.
        # ``self.model.seed`` is presumably provided via _BaseScores -- confirm.
        clf = IsolationForest(random_state=self.model.seed)
        if_labels = clf.fit_predict(recon_error.reshape(-1, 1))
        y_pred_labels = (np.asarray(if_labels) == -1).astype(np.float32)

        # Unweighted average recall = macro-averaged recall over both classes.
        uar_score = recall_score(y_true, y_pred_labels, labels=[0, 1], average='macro',
                                 sample_weight=None, zero_division='warn')

        # Ranking metrics need a continuous score: feeding thresholded 0/1
        # labels collapses the ROC / precision-recall curve to a single
        # operating point. Assumes a larger error means "more anomalous".
        precision_score = average_precision_score(y_true, recon_error)
        auc_score = self._safe_roc_auc(y_true, recon_error)
        pauc_score = self._safe_roc_auc(y_true, recon_error, max_fpr=self.PAUC_MAX_FPR)

        return dict(
            uar_score=uar_score,
            precision_score=precision_score,
            auc_score=auc_score,
            pauc_score=pauc_score,
        )