38 lines
1019 B
Python
38 lines
1019 B
Python
import matplotlib.pyplot as plt
|
|
|
|
from sklearn.metrics import roc_curve, auc
|
|
|
|
|
|
class ROCEvaluation(object):
    """Compute (and optionally plot) the ROC curve and its area for binary scores.

    Calling an instance with prediction scores and ground-truth labels returns
    the area under the ROC curve and, when requested, draws the curve onto the
    current matplotlib axes.
    """

    # Not referenced in this class; presumably the number of classes in a
    # binary problem — TODO confirm against callers.
    BINARY_PROBLEM = 2

    # Line width shared by the chance diagonal and the ROC curve.
    linewidth = 2

    def __init__(self, save_fig=True):
        """Create an evaluator.

        Args:
            save_fig: Stored on the instance but not used by this class;
                presumably meant to control saving the plot — TODO confirm.
        """
        # The argument used to be discarded; keep it available to callers.
        self.save_fig = save_fig
        # Initialised here but not read or updated anywhere in this class.
        self.epoch = 0

    def __call__(self, prediction, label, prepare_fig=True):
        """Return the ROC AUC of ``prediction`` scored against ``label``.

        Args:
            prediction: Target scores (e.g. probabilities), passed to
                ``roc_curve`` as ``y_score``.
            label: Ground-truth binary labels, passed to ``roc_curve`` as
                ``y_true``.
            prepare_fig: When true, draw the chance diagonal, the ROC curve,
                axis labels and a legend onto the current matplotlib axes.

        Returns:
            The area under the ROC curve as computed by ``sklearn.metrics.auc``.
        """
        # sklearn's signature is roc_curve(y_true, y_score): labels come first.
        fpr, tpr, _ = roc_curve(label, prediction)
        roc_auc = auc(fpr, tpr)

        if prepare_fig:
            ax = self._prepare_fig()
            ax.plot(fpr, tpr, color='darkorange',
                    lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
            # Build the legend only after the labelled curve exists.
            ax.legend(loc="lower right")

        return roc_auc

    def _prepare_fig(self):
        """Set up the current axes for an ROC plot and return them.

        Draws the dashed chance diagonal, fixes the axis limits and labels the
        axes. The ``Axes`` (not the ``Figure``) is returned because plotting and
        limit/label setters live on ``Axes``.
        """
        ax = plt.gca()
        ax.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        # Leave a little space above TPR = 1.0 so the top of the curve is visible.
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        return ax
|