bringing brances up to date

This commit is contained in:
Steffen Illium
2021-02-15 11:39:54 +01:00
parent 010176e80b
commit a966321576
11 changed files with 216 additions and 197 deletions

View File

@@ -0,0 +1,47 @@
import numpy as np
from einops import reduce
import torch
from sklearn.ensemble import IsolationForest
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
from ml_lib.metrics._base_score import _BaseScores
class AttentionRollout(_BaseScores):
def __init__(self, *args):
super(AttentionRollout, self).__init__(*args)
pass
def __call__(self, outputs):
summary_dict = dict()
#######################################################################################
# Additional Score - Histogram Distances - Image Plotting
#######################################################################################
#
# INIT
attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
attn_reduce_heads = [reduce(x, '') for x in attn_weights]
if self.model.params.use_residual:
residual_att = np.eye(att_mat.shape[1])[None, ...]
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
else:
aug_att_mat = att_mat
joint_attentions = np.zeros(aug_att_mat.shape)
layers = joint_attentions.shape[0]
joint_attentions[0] = aug_att_mat[0]
for i in np.arange(1, layers):
joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])

View File

@@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores):
#######################################################################################
#
# Confusion matrix
fig1, ax1 = plt.subplots(dpi=96)
cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
labels=[class_names[key] for key in class_names.keys()],
normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=[class_names[i] for i in range(self.model.n_classes)]
)
disp.plot(include_values=True)
disp.plot(include_values=True, ax=ax1)
self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
# self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
plt.close('all')
return summary_dict
return summary_dict