import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc


class ROCEvaluation(object):
    """Compute the ROC curve and its AUC for a binary problem, optionally plotting it.

    Plotting is done with ``matplotlib.pyplot`` on the current Axes
    (``plt.gca()``), so callers control which figure/axes receive the curve.
    """

    # NOTE(review): presumably the number of classes in a binary task; it is
    # not referenced anywhere inside this class.
    BINARY_PROBLEM = 2
    # Line width used for both the chance diagonal and the ROC curve.
    linewidth = 2

    def __init__(self, save_fig=True):
        """Create an evaluator.

        Args:
            save_fig: Stored on the instance as ``self.save_fig``.
                NOTE(review): nothing in this class currently saves a figure;
                confirm whether callers rely on this flag.
        """
        self.save_fig = save_fig
        self.epoch = 0

    def __call__(self, prediction, label, prepare_fig=True):
        """Compute the area under the ROC curve and optionally plot the curve.

        Args:
            prediction: Per-sample scores for the positive class (passed to
                sklearn's ``roc_curve`` as ``y_score``).
            label: Per-sample ground-truth binary labels (passed to
                ``roc_curve`` as ``y_true``).
            prepare_fig: When True, draw the chance diagonal, axis limits and
                labels, the ROC curve and a legend on the current Axes.

        Returns:
            The ROC AUC as computed by ``sklearn.metrics.auc(fpr, tpr)``.
        """
        # sklearn's signature is roc_curve(y_true, y_score): ground-truth labels
        # must come first and the scores second, otherwise the curve is wrong.
        fpr, tpr, _ = roc_curve(label, prediction)
        roc_auc = auc(fpr, tpr)
        if prepare_fig:
            ax = self._prepare_fig()
            ax.plot(fpr, tpr, color='darkorange', lw=self.linewidth,
                    label=f'ROC curve (area = {roc_auc})')
            # Build the legend only after the labelled ROC line exists.
            ax.legend(loc="lower right")
        return roc_auc

    def _prepare_fig(self):
        """Draw the chance diagonal and set limits/labels on the current Axes.

        Returns:
            The ``matplotlib`` Axes obtained from ``plt.gca()``.
        """
        # plt.gcf() would return a Figure, which has no plot/xlim/xlabel
        # methods; the Axes API is what these calls need.
        ax = plt.gca()
        ax.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        return ax