Save imports
This commit is contained in:
parent
645b7905e8
commit
f290d5a8d8
@ -1,4 +1,9 @@
|
|||||||
|
try:
|
||||||
import librosa
|
import librosa
|
||||||
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install librosa`.')
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,14 @@
|
|||||||
|
try:
|
||||||
import librosa
|
import librosa
|
||||||
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install librosa`.')
|
||||||
|
try:
|
||||||
from scipy.signal import butter, lfilter
|
from scipy.signal import butter, lfilter
|
||||||
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `scikit` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install scikit-learn`.')
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -1,13 +1,21 @@
|
|||||||
|
try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from sklearn.metrics import roc_curve, auc
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install matplotlib`.')
|
||||||
|
try:
|
||||||
|
from sklearn.metrics import roc_curve, auc, recall_score
|
||||||
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `sklearn` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install scikit-learn`.')
|
||||||
|
|
||||||
|
|
||||||
class ROCEvaluation(object):
|
class ROCEvaluation(object):
|
||||||
|
|
||||||
linewidth = 2
|
linewidth = 2
|
||||||
|
|
||||||
def __init__(self, plot_roc=False):
|
def __init__(self, plot=False):
|
||||||
self.plot_roc = plot_roc
|
self.plot = plot
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
|
|
||||||
def __call__(self, prediction, label):
|
def __call__(self, prediction, label):
|
||||||
@ -15,7 +23,7 @@ class ROCEvaluation(object):
|
|||||||
# Compute ROC curve and ROC area
|
# Compute ROC curve and ROC area
|
||||||
fpr, tpr, _ = roc_curve(prediction, label)
|
fpr, tpr, _ = roc_curve(prediction, label)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
if self.plot_roc:
|
if self.plot:
|
||||||
_ = plt.gcf()
|
_ = plt.gcf()
|
||||||
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||||
self._prepare_fig()
|
self._prepare_fig()
|
||||||
@ -32,3 +40,32 @@ class ROCEvaluation(object):
|
|||||||
fig.legend(loc="lower right")
|
fig.legend(loc="lower right")
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
class UAREvaluation(object):
|
||||||
|
|
||||||
|
def __init__(self, labels: list, plot=False):
|
||||||
|
self.labels = labels
|
||||||
|
self.plot_roc = plot
|
||||||
|
self.epoch = 0
|
||||||
|
|
||||||
|
def __call__(self, prediction, label):
|
||||||
|
|
||||||
|
# Compute uar score - UnweightedAverageRecal
|
||||||
|
|
||||||
|
uar_score = recall_score(label, prediction, labels=self.labels, average='macro',
|
||||||
|
sample_weight=None, zero_division='warn')
|
||||||
|
return uar_score
|
||||||
|
|
||||||
|
def _prepare_fig(self):
|
||||||
|
raise NotImplementedError # TODO Implement a nice visualization
|
||||||
|
fig = plt.gcf()
|
||||||
|
ax = plt.gca()
|
||||||
|
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
|
plt.ylim([0.0, 1.05])
|
||||||
|
plt.xlabel('False Positive Rate')
|
||||||
|
plt.ylabel('True Positive Rate')
|
||||||
|
fig.legend(loc="lower right")
|
||||||
|
|
||||||
|
return fig
|
@ -1,5 +1,10 @@
|
|||||||
from pathlib import Path
|
try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
except ImportError: # pragma: no-cover
|
||||||
|
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install matplotlib`.')
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class Plotter(object):
|
class Plotter(object):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user