From 76eb567eed628a9a36cf1665b9e154b4bb80fefa Mon Sep 17 00:00:00 2001
From: Steffen
Date: Sat, 27 Mar 2021 18:23:51 +0100
Subject: [PATCH] bugs fixed, binary datasets working

---
 _parameters.ini          |  1 +
 datasets/compare_base.py |  4 +++-
 models/testing.py        | 27 +++++++++++-------------
 util/module_mixins.py    | 45 +++++++++++++++-------------------
 4 files changed, 33 insertions(+), 44 deletions(-)

diff --git a/_parameters.ini b/_parameters.ini
index 3bb70ee..dceb99d 100644
--- a/_parameters.ini
+++ b/_parameters.ini
@@ -37,6 +37,7 @@
 patch_size = 8
 attn_depth = 12
 heads = 4
 embedding_size = 128
+mlp_dim = 32
 
 [CNNBaseline]

diff --git a/datasets/compare_base.py b/datasets/compare_base.py
index ce5e075..0594420 100644
--- a/datasets/compare_base.py
+++ b/datasets/compare_base.py
@@ -144,7 +144,9 @@ class CompareBase(_BaseDataModule):
                 print(f'{data_option} skipped...')
                 continue
 
-            if lab_file is not None:
+            if lab_file is None:
+                lab_file = f'{data_option}.csv'
+            else:
                 if any([x in lab_file for x in data_options]):
                     lab_file = f'{data_option}.csv'
 
diff --git a/models/testing.py b/models/testing.py
index ad5928a..8da7e92 100644
--- a/models/testing.py
+++ b/models/testing.py
@@ -8,6 +8,7 @@ from torch import nn
 
 from einops import rearrange, repeat
 
+from ml_lib.metrics.binary_class_classifictaion import BinaryScores
 from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
@@ -64,11 +65,6 @@ class Tester(CombinedModelMixins,
         self.autopad = AutoPadToShape((self.image_size, self.image_size))
 
         # Modules with Parameters
-        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
-                                             heads=self.params.heads, depth=self.params.attn_depth,
-                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
-                                             activation=self.params.activation, use_residual=self.params.use_residual
-                                             )
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
 
         self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
@@ -78,13 +74,17 @@
 
         self.to_cls_token = nn.Identity()
 
+        logits = self.params.n_classes if self.params.n_classes > 2 else 1
+
+        outbound_activation = nn.Softmax(dim=-1) if logits > 1 else nn.Sigmoid()
+
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
             nn.GELU(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, n_classes),
-            nn.Softmax()
+            nn.Linear(self.params.lat_dim, logits),
+            outbound_activation
         )
 
     def forward(self, x, mask=None, return_attn_weights=False):
@@ -106,15 +106,12 @@ class Tester(CombinedModelMixins,
             tensor += self.pos_embedding[:, :(n + 1)]
             tensor = self.dropout(tensor)
 
-        if return_attn_weights:
-            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
-        else:
-            attn_weights = None
-            tensor = self.transformer(tensor, mask)
-
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
-        return Namespace(main_out=tensor, attn_weights=attn_weights)
+        return Namespace(main_out=tensor, attn_weights=None)
 
     def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
+        if self.params.n_classes > 2:
+            return MultiClassScores(self)(outputs)
+        else:
+            return BinaryScores(self)(outputs)

diff --git a/util/module_mixins.py b/util/module_mixins.py
index ba04994..851e14b 100644
--- a/util/module_mixins.py
+++ b/util/module_mixins.py
@@ -96,39 +96,27 @@ class ValMixin:
             for file_name in sorted_y:
                 sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
 
-            #y_mean = torch.stack(
-            #    [torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
-            #).squeeze()
-
-            #if y_mean.ndim == 1:
-            #    y_mean = y_mean.unsqueeze(0)
-            #if sorted_batch_y.ndim == 1:
-            #    sorted_batch_y = sorted_batch_y.unsqueeze(-1)
-            #
-            #mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
-            #summary_dict.update(val_mean_vote_loss=mean_vote_loss)
-
             if self.params.n_classes <= 2:
-                mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
+                mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
                 max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
+                # Sklearn Scores
+                additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
+
             else:
                 y_max = torch.stack(
                     [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
                 ).squeeze()
                 y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
                 max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
-            summary_dict.update(val_max_vote_loss=max_vote_loss)
+                # Sklearn Scores
+                additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
+
+            summary_dict.update(val_max_vote_loss=max_vote_loss, **additional_scores)
 
             summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                         for output in outputs]))
                                  for key in keys if 'loss' in key}
                                 )
 
-            # Sklearn Scores
-            if self.params.n_classes <= 2:
-                additional_scores = self.additional_scores(dict(y=y_max, batch_y=sorted_batch_y))
-            else:
-                additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
-            summary_dict.update(**additional_scores)
 
             pl_metrics, pl_images = self.metrics.compute_and_prepare()
             self.metrics.reset()
@@ -166,19 +154,20 @@ class TestMixin:
             for file_name in sorted_y:
                 sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
 
-            y_max = torch.stack(
-                [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
-            ).squeeze().cpu()
-            if self.params.n_classes == 5:
+
+            if self.params.n_classes > 2:
+                pred = torch.stack(
+                    [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
+                ).squeeze().cpu()
                 class_names = {val: key for val, key in enumerate(['background', 'chimpanze', 'geunon',
                                                                    'mandrille', 'redcap'])}
-            elif self.params.n_classes == 2:
-                class_names = {val: key for val, key in ['negative', 'positive']}
             else:
-                raise AttributeError(f'n_classes has to be any of: [2, 5]')
+                pred = (torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze() > 0.5).long()
+                class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
+
             df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
-                                        prediction=[class_names[x.item()] for x in y_max.cpu()]))
+                                        prediction=[class_names[x.item()] for x in pred.cpu()]))
             result_file = Path(self.logger.log_dir / 'predictions.csv')
             if result_file.exists():
                 try:
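
Note: below is a minimal standalone sketch of the head selection this patch introduces in models/testing.py, useful for a quick shape check outside the training pipeline. The helper name build_head and the values (embed_dim=128, lat_dim=32, dropout=0.1) are illustrative only and are not read from _parameters.ini or the Tester class.

import torch
from torch import nn


def build_head(embed_dim, lat_dim, n_classes, dropout=0.1):
    # Mirrors the patched mlp_head: binary tasks get a single sigmoid unit
    # (paired with bce_loss in ValMixin), multi-class tasks get n_classes
    # softmax outputs (paired with ce_loss).
    n_out = n_classes if n_classes > 2 else 1
    activation = nn.Softmax(dim=-1) if n_out > 1 else nn.Sigmoid()
    return nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, lat_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(lat_dim, n_out),
        activation,
    )


cls_token = torch.randn(4, 128)                            # batch of 4 CLS embeddings
print(build_head(128, 32, n_classes=2)(cls_token).shape)   # torch.Size([4, 1]) -> sigmoid + BCE
print(build_head(128, 32, n_classes=5)(cls_token).shape)   # torch.Size([4, 5]) -> softmax + CE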