Final Train Runs

This commit is contained in:
Steffen Illium
2021-03-18 07:45:06 +01:00
parent f89f0f8528
commit fc4617c9d8
4 changed files with 49 additions and 18 deletions

View File

@ -26,10 +26,16 @@ class MultiClassScores(_BaseScores):
#######################################################################################
#
# INIT
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
if isinstance(outputs['batch_y'], torch.Tensor):
y_true = outputs['batch_y'].cpu().numpy()
else:
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_true_one_hot = to_one_hot(y_true, self.model.params.n_classes)
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
if isinstance(outputs['y'], torch.Tensor):
y_pred = outputs['y'].cpu().numpy()
else:
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for val, key in enumerate(class_names)}