Urban 8k Train running with newest Lightning and pytorch
This commit is contained in:
56
metrics/binary_class_classifictaion.py
Normal file
56
metrics/binary_class_classifictaion.py
Normal file
@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
|
||||
|
||||
class BinaryScores(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(BinaryScores, self).__init__(*args)
|
||||
|
||||
def __call__(self, outputs):
|
||||
summary_dict = dict()
|
||||
|
||||
# Additional Score like the unweighted Average Recall:
|
||||
#########################
|
||||
# UnweightedAverageRecall
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
|
||||
y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
# How to apply a threshold manualy
|
||||
# y_pred = (y_pred >= 0.5).astype(np.float32)
|
||||
|
||||
# How to apply a threshold by IF (Isolation Forest)
|
||||
clf = IsolationForest(random_state=self.model.seed)
|
||||
y_score = clf.fit_predict(y_pred.reshape(-1,1))
|
||||
y_score = (np.asarray(y_score) == -1).astype(np.float32)
|
||||
|
||||
uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
|
||||
sample_weight=None, zero_division='warn')
|
||||
summary_dict.update(dict(uar_score=uar_score))
|
||||
#########################
|
||||
# Precission
|
||||
precision_score = average_precision_score(y_true, y_score)
|
||||
summary_dict.update(dict(precision_score=precision_score))
|
||||
|
||||
#########################
|
||||
# AUC
|
||||
try:
|
||||
auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
|
||||
summary_dict.update(dict(auc_score=auc_score))
|
||||
except ValueError:
|
||||
summary_dict.update(dict(auc_score=-1))
|
||||
|
||||
#########################
|
||||
# pAUC
|
||||
try:
|
||||
pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
|
||||
summary_dict.update(dict(pauc_score=pauc))
|
||||
except ValueError:
|
||||
summary_dict.update(dict(pauc_score=-1))
|
||||
|
||||
return summary_dict
|
Reference in New Issue
Block a user