ml_lib/metrics/generative_task_evaluation.py
2021-02-01 10:23:22 +01:00

68 lines
2.6 KiB
Python

from itertools import cycle
import numpy as np
import torch
from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
from scipy.spatial.distance import cdist
from ml_lib.metrics._base_score import _BaseScores
from matplotlib import pyplot as plt
class GenerativeTaskEval(_BaseScores):
    """Epoch-level evaluation for generative models.

    Compares the value distributions of targets and predictions via histogram
    distances, and logs one randomly chosen target image together with its
    attention weights.
    """

    def __init__(self, *args):
        # All arguments are forwarded untouched to _BaseScores.
        # NOTE(review): __call__ relies on ``self.model`` (with ``.logger`` and
        # ``.current_epoch``); presumably _BaseScores sets it — verify there.
        super().__init__(*args)

    def __call__(self, outputs):
        """Compute histogram distances and log an attention overlay image.

        Args:
            outputs: iterable of per-step dicts, each holding torch tensors
                under ``'batch_y'`` (targets), ``'y'`` (predictions) and
                ``'attn_weights'``. The tensors are concatenated along dim 0,
                so they are assumed to be batched on the first axis — TODO
                confirm against the model's step outputs.

        Returns:
            dict with ``'hist_manhattan_dist'`` and ``'hist_euc_dist'``, each
            the (1, 1) array returned by ``scipy.spatial.distance.cdist``.
        """
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        # np.histogram flattens its input, so no squeeze is needed for y_pred.
        y_pred = torch.cat([output['y'] for output in outputs]).cpu().numpy()
        # Deliberately not squeezed as a whole: with a single sample a global
        # squeeze would also drop the batch axis and break attn_weights[idx].
        attn_weights = torch.cat([output['attn_weights'] for output in outputs]).cpu().numpy()

        summary_dict = self._histogram_distances(y_true, y_pred)
        self._log_attention_overlay(y_true, attn_weights)

        plt.close('all')
        return summary_dict

    @staticmethod
    def _histogram_distances(y_true, y_pred, bins=128):
        """Return Manhattan and Euclidean distances between value histograms.

        Both histograms are built over the same bin edges, spanning the joint
        value range of ``y_true`` and ``y_pred``. Without a shared range, bin
        ``i`` would cover different values in each histogram and the
        distances would not compare like with like.

        Args:
            y_true: target values (any shape; flattened by np.histogram).
            y_pred: predicted values (any shape; flattened by np.histogram).
            bins: number of histogram bins. TODO: find a better value.

        Returns:
            dict with keys ``'hist_manhattan_dist'`` and ``'hist_euc_dist'``.
        """
        # One shared range => identical bin edges for both histograms.
        value_range = (float(min(y_true.min(), y_pred.min())),
                       float(max(y_true.max(), y_pred.max())))
        y_true_hist = np.histogram(y_true, bins=bins, range=value_range)[0]
        y_pred_hist = np.histogram(y_pred, bins=bins, range=value_range)[0]

        # cdist expects 2-D (n_observations, n_features) inputs.
        true_row = np.expand_dims(y_true_hist, axis=0)
        pred_row = np.expand_dims(y_pred_hist, axis=0)

        # L2 norm == euclidean distance
        hist_euc_dist = cdist(true_row, pred_row, metric='euclidean')
        # Manhattan (L1 / cityblock) distance
        hist_manhattan_dist = cdist(true_row, pred_row, metric='cityblock')
        return dict(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist)

    def _log_attention_overlay(self, y_true, attn_weights):
        """Log one random target image with its attention weights drawn on top."""
        idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()
        target_img = plt.imshow(y_true[idx].squeeze())
        # Draw the attention map over the same data extent as the target image,
        # so the two line up even when their pixel resolutions differ.
        # NOTE(review): no alpha is set, so this fully covers the target image;
        # consider alpha=... if an actual overlay is intended.
        plt.imshow(attn_weights[idx].squeeze(), interpolation='nearest', aspect='auto',
                   extent=target_img.get_extent())
        # NOTE(review): the 'ROC' tag looks copy-pasted from a classification
        # evaluator; kept unchanged so existing logger dashboards keep working.
        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
        plt.clf()