eval written
This commit is contained in:
@ -1,29 +1,24 @@
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
|
||||
class ROCEvaluation(object):
|
||||
|
||||
BINARY_PROBLEM = 2
|
||||
linewidth = 2
|
||||
|
||||
def __init__(self, save_fig=True):
|
||||
def __init__(self, prepare_figure=False):
|
||||
self.prepare_figure = prepare_figure
|
||||
self.epoch = 0
|
||||
pass
|
||||
|
||||
def __call__(self, prediction, label, prepare_fig=True):
|
||||
def __call__(self, prediction, label, plotting=False):
|
||||
|
||||
# Compute ROC curve and ROC area
|
||||
fpr, tpr, _ = roc_curve(prediction, label)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
if prepare_fig:
|
||||
fig = self._prepare_fig()
|
||||
fig.plot(fpr, tpr, color='darkorange',
|
||||
lw=2, label=f'ROC curve (area = {roc_auc})')
|
||||
self._prepare_fig()
|
||||
return roc_auc
|
||||
if plotting:
|
||||
fig = plt.gcf()
|
||||
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||
return roc_auc, fpr, tpr
|
||||
|
||||
def _prepare_fig(self):
|
||||
fig = plt.gcf()
|
||||
@ -32,6 +27,6 @@ class ROCEvaluation(object):
|
||||
fig.ylim([0.0, 1.05])
|
||||
fig.xlabel('False Positive Rate')
|
||||
fig.ylabel('True Positive Rate')
|
||||
|
||||
fig.legend(loc="lower right")
|
||||
|
||||
return fig
|
||||
|
Reference in New Issue
Block a user