CCS intergration training running

notebooks
This commit is contained in:
Steffen
2021-03-24 08:03:12 +01:00
parent c12f3866c8
commit 82835295a1
11 changed files with 1264 additions and 445 deletions

View File

@@ -18,7 +18,7 @@ class TrainMixin:
batch_files, batch_x, batch_y = batch_xy
y = self(batch_x).main_out
if self.params.n_classes <= 2:
loss = self.bce_loss(y, batch_y.long())
loss = self.bce_loss(y.squeeze().float(), batch_y.float())
else:
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)