CCS intergration dataloader

This commit is contained in:
Steffen
2021-03-19 18:05:17 +01:00
parent d30edbda6e
commit 78b3139d1a
2 changed files with 21 additions and 12 deletions

View File

@@ -59,7 +59,13 @@ class ValMixin:
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
).squeeze()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
self.metrics.update(y_one_hot, torch.stack(tuple(sorted_batch_y.values())).long())
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
if y_one_hot.ndim == 1:
y_one_hot = y_one_hot.unsqueeze(0)
if target_y.ndim == 1:
target_y = target_y.unsqueeze(0)
self.metrics.update(y_one_hot, target_y)
val_loss = self.ce_loss(y, batch_y.long())