CCS intergration dataloader

This commit is contained in:
Steffen
2021-03-19 17:17:16 +01:00
parent 6ace861016
commit d4059779c4
8 changed files with 213 additions and 35 deletions

View File

@@ -19,7 +19,7 @@ class TrainMixin:
y = self(batch_x).main_out
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
else:
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
@@ -58,7 +58,7 @@ class ValMixin:
y_max = torch.stack(
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
).squeeze()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=5).float()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
self.metrics.update(y_one_hot, torch.stack(tuple(sorted_batch_y.values())).long())
val_loss = self.ce_loss(y, batch_y.long())
@@ -96,7 +96,7 @@ class ValMixin:
y_max = torch.stack(
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
).squeeze()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=5).float()
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
summary_dict.update(val_max_vote_loss=max_vote_loss)
@@ -145,7 +145,11 @@ class TestMixin:
y_max = torch.stack(
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
).squeeze().cpu()
class_names = {val: key for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
if self.params.n_classes == 5:
class_names = {val: key for val, key in
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
elif self.params.n_classes == 2:
class_names = {val: key for val, key in ['negative', 'positive']}
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
prediction=y_max.cpu().numpy()))
@@ -154,7 +158,7 @@ class TestMixin:
try:
result_file.unlink()
except:
print('File allready existed')
print('File already existed')
pass
with result_file.open(mode='wb') as csv_file:
df.to_csv(index=False, path_or_buf=csv_file)