adjustment fot CCS, notebook folder

This commit is contained in:
Steffen Illium
2021-03-22 16:43:19 +01:00
parent 78b3139d1a
commit c12f3866c8
6 changed files with 156 additions and 29 deletions

View File

@@ -17,13 +17,14 @@ class TrainMixin:
assert isinstance(self, LightningBaseModule)
batch_files, batch_x, batch_y = batch_xy
y = self(batch_x).main_out
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
if self.params.n_classes <= 2:
loss = self.bce_loss(y, batch_y.long())
else:
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
else:
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
return dict(loss=loss)
def training_epoch_end(self, outputs):
@@ -60,14 +61,17 @@ class ValMixin:
).squeeze()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
if y_one_hot.ndim == 1:
y_one_hot = y_one_hot.unsqueeze(0)
if target_y.ndim == 1:
target_y = target_y.unsqueeze(0)
if self.params.n_classes <= 2:
if y_one_hot.ndim == 1:
y_one_hot = y_one_hot.unsqueeze(0)
if target_y.ndim == 1:
target_y = target_y.unsqueeze(-1)
self.metrics.update(y_one_hot, target_y)
val_loss = self.ce_loss(y, batch_y.long())
if self.params.n_classes <= 2:
val_loss = self.bce_loss(y.squeeze().float(), batch_y.float())
else:
val_loss = self.ce_loss(y, batch_y.long())
return dict(batch_files=batch_files, val_loss=val_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
@@ -93,17 +97,26 @@ class ValMixin:
for file_name in sorted_y:
sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
y_mean = torch.stack(
[torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
).squeeze()
mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
summary_dict.update(val_mean_vote_loss=mean_vote_loss)
#y_mean = torch.stack(
# [torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
#).squeeze()
#if y_mean.ndim == 1:
# y_mean = y_mean.unsqueeze(0)
#if sorted_batch_y.ndim == 1:
# sorted_batch_y = sorted_batch_y.unsqueeze(-1)
#
#mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
#summary_dict.update(val_mean_vote_loss=mean_vote_loss)
y_max = torch.stack(
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
).squeeze()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
if self.params.n_classes >= 2:
max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
else:
max_vote_loss = self.bce_loss(y_one_hot, sorted_batch_y)
summary_dict.update(val_max_vote_loss=max_vote_loss)
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
@@ -156,6 +169,8 @@ class TestMixin:
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
elif self.params.n_classes == 2:
class_names = {val: key for val, key in ['negative', 'positive']}
else:
raise AttributeError(f'n_classes has to be any of: [2, 5]')
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
prediction=[class_names[x.item()] for x in y_max.cpu()]))