CCS intergration dataloader
This commit is contained in:
@@ -59,7 +59,13 @@ class ValMixin:
|
||||
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
|
||||
).squeeze()
|
||||
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
|
||||
self.metrics.update(y_one_hot, torch.stack(tuple(sorted_batch_y.values())).long())
|
||||
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
|
||||
if y_one_hot.ndim == 1:
|
||||
y_one_hot = y_one_hot.unsqueeze(0)
|
||||
if target_y.ndim == 1:
|
||||
target_y = target_y.unsqueeze(0)
|
||||
|
||||
self.metrics.update(y_one_hot, target_y)
|
||||
|
||||
val_loss = self.ce_loss(y, batch_y.long())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user