bringing brances up to date
This commit is contained in:
@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores):
|
||||
#######################################################################################
|
||||
#
|
||||
# Confusion matrix
|
||||
|
||||
fig1, ax1 = plt.subplots(dpi=96)
|
||||
cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
|
||||
labels=[class_names[key] for key in class_names.keys()],
|
||||
normalize='all')
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
|
||||
display_labels=[class_names[i] for i in range(self.model.n_classes)]
|
||||
)
|
||||
disp.plot(include_values=True)
|
||||
disp.plot(include_values=True, ax=ax1)
|
||||
|
||||
self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
|
||||
self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
|
||||
# self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
|
||||
|
||||
plt.close('all')
|
||||
return summary_dict
|
||||
return summary_dict
|
||||
|
Reference in New Issue
Block a user