ml_lib/metrics/binary_class_classifictaion.py


import numpy as np
import torch
from sklearn.ensemble import IsolationForest
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
from ml_lib.metrics._base_score import _BaseScores


class BinaryScores(_BaseScores):

    def __init__(self, *args):
        super(BinaryScores, self).__init__(*args)

    def __call__(self, outputs):
        summary_dict = dict()
        # Additional scores, e.g. the Unweighted Average Recall (UAR):
        #########################
        # Unweighted Average Recall
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()

        # How to apply a threshold manually:
        # y_pred = (y_pred >= 0.5).astype(np.float32)

        # How to apply a threshold with an Isolation Forest (IF):
        # fit_predict returns -1 for outliers and 1 for inliers, so outliers
        # are mapped to the positive class (1.0).
        clf = IsolationForest(random_state=self.model.seed)
        y_score = clf.fit_predict(y_pred.reshape(-1, 1))
        y_score = (np.asarray(y_score) == -1).astype(np.float32)

        uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
                                 sample_weight=None, zero_division='warn')
        summary_dict.update(dict(uar_score=uar_score))
        #########################
        # Average Precision
        # Note: average_precision_score is computed here on the binarized IF
        # predictions; the raw reconstruction errors in y_pred would give a
        # more fine-grained ranking.
        precision_score = average_precision_score(y_true, y_score)
        summary_dict.update(dict(precision_score=precision_score))
        #########################
        # AUC
        # roc_auc_score raises a ValueError if y_true contains only one class.
        try:
            auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
            summary_dict.update(dict(auc_score=auc_score))
        except ValueError:
            summary_dict.update(dict(auc_score=-1))

        #########################
        # pAUC (partial AUC, evaluated up to a false-positive rate of 0.15)
        try:
            pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
            summary_dict.update(dict(pauc_score=pauc))
        except ValueError:
            summary_dict.update(dict(pauc_score=-1))

        return summary_dict
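
# --- Hedged usage sketch (illustration only, not part of the original module) ---
# Shows how the Isolation-Forest thresholding used in __call__ maps continuous
# reconstruction errors to binary predictions: fit_predict flags outliers as -1,
# which are converted to the positive class (1.0). The synthetic data and the
# seed below are assumptions made purely for this demo.
if __name__ == '__main__':
    rng = np.random.default_rng(seed=42)
    # 95 "normal" samples with small reconstruction errors, 5 anomalies with large ones.
    recon_error = np.concatenate([rng.normal(0.2, 0.05, size=95),
                                  rng.normal(2.0, 0.30, size=5)])
    y_true_demo = np.concatenate([np.zeros(95), np.ones(5)])

    clf_demo = IsolationForest(random_state=42)
    y_score_demo = (clf_demo.fit_predict(recon_error.reshape(-1, 1)) == -1).astype(np.float32)

    print('demo UAR:', recall_score(y_true_demo, y_score_demo, labels=[0, 1], average='macro'))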