Save imports

This commit is contained in:
Si11ium 2020-05-19 10:03:35 +02:00
parent 645b7905e8
commit f290d5a8d8
4 changed files with 65 additions and 9 deletions

View File

@ -1,4 +1,9 @@
import librosa try:
import librosa
except ImportError: # pragma: no-cover
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install librosa`.')
import numpy as np import numpy as np

View File

@ -1,5 +1,14 @@
import librosa try:
from scipy.signal import butter, lfilter import librosa
except ImportError: # pragma: no-cover
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install librosa`.')
try:
from scipy.signal import butter, lfilter
except ImportError: # pragma: no-cover
raise ImportError('You want to use `scikit` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install scikit-learn`.')
import numpy as np import numpy as np

View File

@ -1,13 +1,21 @@
import matplotlib.pyplot as plt try:
from sklearn.metrics import roc_curve, auc import matplotlib.pyplot as plt
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
try:
from sklearn.metrics import roc_curve, auc, recall_score
except ImportError: # pragma: no-cover
raise ImportError('You want to use `sklearn` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install scikit-learn`.')
class ROCEvaluation(object): class ROCEvaluation(object):
linewidth = 2 linewidth = 2
def __init__(self, plot_roc=False): def __init__(self, plot=False):
self.plot_roc = plot_roc self.plot = plot
self.epoch = 0 self.epoch = 0
def __call__(self, prediction, label): def __call__(self, prediction, label):
@ -15,7 +23,7 @@ class ROCEvaluation(object):
# Compute ROC curve and ROC area # Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label) fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr) roc_auc = auc(fpr, tpr)
if self.plot_roc: if self.plot:
_ = plt.gcf() _ = plt.gcf()
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})') plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig() self._prepare_fig()
@ -32,3 +40,32 @@ class ROCEvaluation(object):
fig.legend(loc="lower right") fig.legend(loc="lower right")
return fig return fig
class UAREvaluation(object):
def __init__(self, labels: list, plot=False):
self.labels = labels
self.plot_roc = plot
self.epoch = 0
def __call__(self, prediction, label):
# Compute uar score - UnweightedAverageRecal
uar_score = recall_score(label, prediction, labels=self.labels, average='macro',
sample_weight=None, zero_division='warn')
return uar_score
def _prepare_fig(self):
raise NotImplementedError # TODO Implement a nice visualization
fig = plt.gcf()
ax = plt.gca()
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig

View File

@ -1,5 +1,10 @@
try:
import matplotlib.pyplot as plt
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt
class Plotter(object): class Plotter(object):