transition
This commit is contained in:
68
metrics/generative_task_evaluation.py
Normal file
68
metrics/generative_task_evaluation.py
Normal file
@ -0,0 +1,68 @@
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
class GenerativeTaskEval(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(GenerativeTaskEval, self).__init__(*args)
|
||||
pass
|
||||
|
||||
def __call__(self, outputs):
|
||||
summary_dict = dict()
|
||||
#######################################################################################
|
||||
# Additional Score - Histogram Distances - Image Plotting
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
attn_weights = torch.cat([output['attn_weights'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
######################################################################################
|
||||
#
|
||||
# Histogram comparission
|
||||
|
||||
y_true_hist = np.histogram(y_true, bins=128)[0] # Todo: Find a better value
|
||||
y_pred_hist = np.histogram(y_pred, bins=128)[0] # Todo: Find a better value
|
||||
|
||||
# L2 norm == euclidean distance
|
||||
hist_euc_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
|
||||
metric='euclidean')
|
||||
|
||||
# Manhattan Distance
|
||||
hist_manhattan_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
|
||||
metric='cityblock')
|
||||
|
||||
summary_dict.update(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist)
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()
|
||||
|
||||
ax = plt.imshow(y_true[idx].squeeze())
|
||||
# Plot using a small number of colors, with unevenly spaced boundaries.
|
||||
ax2 = plt.imshow(attn_weights[idx].sq, interpolation='nearest', aspect='auto', extent=ax.get_extent())
|
||||
self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
|
||||
plt.clf()
|
||||
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
|
||||
plt.close('all')
|
||||
return summary_dict
|
Reference in New Issue
Block a user