Final Train Runs
This commit is contained in:
@ -26,10 +26,16 @@ class MultiClassScores(_BaseScores):
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
if isinstance(outputs['batch_y'], torch.Tensor):
|
||||
y_true = outputs['batch_y'].cpu().numpy()
|
||||
else:
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
y_true_one_hot = to_one_hot(y_true, self.model.params.n_classes)
|
||||
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
if isinstance(outputs['y'], torch.Tensor):
|
||||
y_pred = outputs['y'].cpu().numpy()
|
||||
else:
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
y_pred_max = np.argmax(y_pred, axis=1)
|
||||
|
||||
class_names = {val: key for val, key in enumerate(class_names)}
|
||||
|
Reference in New Issue
Block a user