CCS intergration training running
notebooks
This commit is contained in:
@@ -18,7 +18,7 @@ class TrainMixin:
|
||||
batch_files, batch_x, batch_y = batch_xy
|
||||
y = self(batch_x).main_out
|
||||
if self.params.n_classes <= 2:
|
||||
loss = self.bce_loss(y, batch_y.long())
|
||||
loss = self.bce_loss(y.squeeze().float(), batch_y.float())
|
||||
else:
|
||||
if self.params.loss == 'focal_loss_rob':
|
||||
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
|
||||
|
||||
Reference in New Issue
Block a user