68 lines
2.6 KiB
Python
68 lines
2.6 KiB
Python
from itertools import cycle
|
|
|
|
import numpy as np
|
|
import torch
|
|
from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
|
|
from scipy.spatial.distance import cdist
|
|
|
|
from ml_lib.metrics._base_score import _BaseScores
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
class GenerativeTaskEval(_BaseScores):
    """Epoch-level evaluation for generative models.

    Compares the value distributions of targets and predictions via histogram
    distances, and logs an image of one randomly chosen target with its
    attention weights drawn over it.
    """

    # Number of histogram bins used for the distribution comparison.
    # TODO: find a better value (carried over from the original implementation).
    HIST_BINS = 128

    def __init__(self, *args):
        super().__init__(*args)

    def __call__(self, outputs):
        """Compute histogram distances and log an attention overlay image.

        :param outputs: sequence of per-batch dicts; each must provide tensors
            under 'batch_y' (targets), 'y' (predictions) and 'attn_weights',
            concatenable along dim 0.
        :return: dict with 'hist_manhattan_dist' and 'hist_euc_dist', each the
            (1, 1) ndarray produced by ``cdist``. Empty when ``outputs`` is empty.
        """
        summary_dict = dict()
        if not outputs:
            # torch.cat([]) raises; there is nothing to evaluate.
            return summary_dict

        # detach() so .numpy() also works on tensors that still require grad.
        y_true = torch.cat([output['batch_y'] for output in outputs]).detach().cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).detach().squeeze().cpu().numpy()
        # Deliberately not squeezed as a whole: with a single sample that would
        # drop the batch axis and break the per-sample indexing further down.
        attn_weights = torch.cat([output['attn_weights'] for output in outputs]).detach().cpu().numpy()

        summary_dict.update(self._histogram_distances(y_true, y_pred))
        self._log_attention_overlay(y_true, attn_weights)

        plt.close('all')
        return summary_dict

    @classmethod
    def _histogram_distances(cls, y_true, y_pred):
        """Return Manhattan and Euclidean distances between value histograms.

        Both histograms use the same bin edges, spanning the combined value
        range of targets and predictions, so bin i refers to the same interval
        in both; with independently chosen edges the per-bin counts would not
        be comparable.
        """
        low = float(min(y_true.min(), y_pred.min()))
        high = float(max(y_true.max(), y_pred.max()))
        y_true_hist = np.histogram(y_true, bins=cls.HIST_BINS, range=(low, high))[0]
        y_pred_hist = np.histogram(y_pred, bins=cls.HIST_BINS, range=(low, high))[0]

        # cdist expects 2-D inputs: one row per histogram.
        true_row = np.expand_dims(y_true_hist, axis=0)
        pred_row = np.expand_dims(y_pred_hist, axis=0)
        return dict(
            # Manhattan (L1) distance
            hist_manhattan_dist=cdist(true_row, pred_row, metric='cityblock'),
            # Euclidean (L2) distance
            hist_euc_dist=cdist(true_row, pred_row, metric='euclidean'),
        )

    def _log_attention_overlay(self, y_true, attn_weights):
        """Log one random target image with its attention map drawn on top.

        Relies on ``self.model`` (presumably set by the base class — confirm)
        exposing ``logger.log_image`` and ``current_epoch``.
        """
        idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()

        target_img = plt.imshow(y_true[idx].squeeze())
        # Stretch the attention map over the target's extent so the two line up
        # even when their resolutions differ.
        # NOTE(review): without an alpha value this map fully hides the target image.
        plt.imshow(attn_weights[idx].squeeze(), interpolation='nearest', aspect='auto',
                   extent=target_img.get_extent())
        # NOTE(review): the 'ROC' tag looks copied from a classification evaluator;
        # kept unchanged so existing logs/dashboards keep matching.
        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
        plt.clf()