bugs fixed, binary datasets working

Steffen 2021-03-27 18:23:51 +01:00
parent 37e36df0a8
commit 76eb567eed
4 changed files with 33 additions and 44 deletions

View File

@@ -37,6 +37,7 @@ patch_size = 8
 attn_depth = 12
 heads = 4
 embedding_size = 128
+mlp_dim = 32
 [CNNBaseline]

View File

@@ -144,7 +144,9 @@ class CompareBase(_BaseDataModule):
                 print(f'{data_option} skipped...')
                 continue
-            if lab_file is not None:
+            if lab_file is None:
+                lab_file = f'{data_option}.csv'
+            elif lab_file is not None:
                 if any([x in lab_file for x in data_options]):
                     lab_file = f'{data_option}.csv'
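This is what makes binary datasets work without an explicit label file: a data option with no label file now falls back to its own `<data_option>.csv`. A minimal sketch of the resulting selection logic; only `lab_file`, `data_option` and `data_options` come from the diff, the helper name, typing and the example option names are illustrative.

from typing import Optional

def resolve_label_file(data_option: str, data_options: list, lab_file: Optional[str]) -> str:
    # Sketch only: mirrors the branch added in CompareBase above.
    if lab_file is None:
        # No label file given -> default to the csv named after the current data option.
        lab_file = f'{data_option}.csv'
    elif any(x in lab_file for x in data_options):
        # A label file naming any known data option is re-pointed to the current one.
        lab_file = f'{data_option}.csv'
    return lab_file

# resolve_label_file('compare21', ['compare21', 'primates'], None)            -> 'compare21.csv'
# resolve_label_file('compare21', ['compare21', 'primates'], 'primates.csv')  -> 'compare21.csv'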

View File

@@ -8,6 +8,7 @@ from torch import nn
 from einops import rearrange, repeat
+from ml_lib.metrics.binary_class_classifictaion import BinaryScores
 from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
@@ -64,11 +65,6 @@ class Tester(CombinedModelMixins,
         self.autopad = AutoPadToShape((self.image_size, self.image_size))
         # Modules with Parameters
-        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
-                                             heads=self.params.heads, depth=self.params.attn_depth,
-                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
-                                             activation=self.params.activation, use_residual=self.params.use_residual
-                                             )
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
         self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
@@ -78,13 +74,17 @@ class Tester(CombinedModelMixins,
         self.to_cls_token = nn.Identity()
+        logits = self.params.n_classes if self.params.n_classes > 2 else 1
+        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
             nn.GELU(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, n_classes),
-            nn.Softmax()
+            nn.Linear(self.params.lat_dim, logits),
+            outbound_activation()
         )
     def forward(self, x, mask=None, return_attn_weights=False):
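The head now emits a single sigmoid output for binary runs and keeps softmax over n_classes otherwise. A standalone sketch of that wiring, assuming only the hyperparameter names visible above; the builder function itself is not part of the commit.

import torch
from torch import nn

def build_mlp_head(embed_dim: int, lat_dim: int, dropout: float, n_classes: int) -> nn.Sequential:
    logits = n_classes if n_classes > 2 else 1               # binary collapses to one output unit
    outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
    return nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, lat_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(lat_dim, logits),
        outbound_activation(),                               # nn.Softmax(dim=-1) would make the multi-class dim explicit
    )

# head = build_mlp_head(embed_dim=128, lat_dim=32, dropout=0.1, n_classes=2)
# head(torch.randn(4, 128)).shape  -> torch.Size([4, 1]), one probability per sample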
@@ -106,15 +106,12 @@ class Tester(CombinedModelMixins,
         tensor += self.pos_embedding[:, :(n + 1)]
         tensor = self.dropout(tensor)
-        if return_attn_weights:
-            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
-        else:
-            attn_weights = None
-            tensor = self.transformer(tensor, mask)
+        tensor = self.transformer(tensor, mask)
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
-        return Namespace(main_out=tensor, attn_weights=attn_weights)
+        return Namespace(main_out=tensor, attn_weights=None)
     def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
+        if self.params.n_classes > 2:
+            return MultiClassScores(self)(outputs)
+        else:
+            return BinaryScores(self)(outputs)

View File

@@ -96,39 +96,27 @@ class ValMixin:
         for file_name in sorted_y:
             sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
-        #y_mean = torch.stack(
-        #    [torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
-        #).squeeze()
-        #if y_mean.ndim == 1:
-        #    y_mean = y_mean.unsqueeze(0)
-        #if sorted_batch_y.ndim == 1:
-        #    sorted_batch_y = sorted_batch_y.unsqueeze(-1)
-        #
-        #mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
-        #summary_dict.update(val_mean_vote_loss=mean_vote_loss)
         if self.params.n_classes <= 2:
-            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
+            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
             max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
+            # Sklearn Scores
+            additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
         else:
             y_max = torch.stack(
                 [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
             ).squeeze()
             y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
             max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
-        summary_dict.update(val_max_vote_loss=max_vote_loss)
+            # Sklearn Scores
+            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
+        summary_dict.update(val_max_vote_loss=max_vote_loss, **additional_scores)
         summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                     for output in outputs]))
                              for key in keys if 'loss' in key}
                             )
-        # Sklearn Scores
-        if self.params.n_classes <= 2:
-            additional_scores = self.additional_scores(dict(y=y_max, batch_y=sorted_batch_y))
-        else:
-            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
-        summary_dict.update(**additional_scores)
         pl_metrics, pl_images = self.metrics.compute_and_prepare()
         self.metrics.reset()
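The max-vote block now also branches on the number of classes: binary runs average the per-window sigmoid scores per file and score them with BCE, multi-class runs take the argmax of the averaged distribution, one-hot encode it and score with CE. A hedged sketch of that aggregation in isolation; the names follow the diff, while the loss instances and the toy tensors in the usage comment are assumptions.

import torch
from torch import nn

def max_vote(sorted_y: dict, sorted_batch_y: torch.Tensor, n_classes: int):
    # Sketch of the validation aggregation above, outside the Lightning mixin.
    bce_loss, ce_loss = nn.BCELoss(), nn.CrossEntropyLoss()
    if n_classes <= 2:
        # Binary: mean sigmoid score per file, scored against 0/1 targets.
        mean_sorted_y = torch.stack(
            [x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]
        ).squeeze()
        return bce_loss(mean_sorted_y.float(), sorted_batch_y.float()), mean_sorted_y
    # Multi-class: argmax of the averaged class distribution, one-hot encoded for CE.
    y_max = torch.stack(
        [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
    ).squeeze()
    y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=n_classes).float()
    return ce_loss(y_one_hot, sorted_batch_y), y_one_hot

# Two files with two windows each (assumed shapes), binary targets:
# sorted_y = {'a.wav': torch.tensor([[0.8], [0.6]]), 'b.wav': torch.tensor([[0.2], [0.1]])}
# max_vote(sorted_y, torch.tensor([1, 0]), n_classes=2)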
@@ -166,19 +154,20 @@ class TestMixin:
         for file_name in sorted_y:
             sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
-        y_max = torch.stack(
-            [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
-        ).squeeze().cpu()
-        if self.params.n_classes == 5:
+        if self.params.n_classes > 2:
+            pred = torch.stack(
+                [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
+            ).squeeze().cpu()
             class_names = {val: key for val, key in
                            enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
-        elif self.params.n_classes == 2:
-            class_names = {val: key for val, key in ['negative', 'positive']}
         else:
-            raise AttributeError(f'n_classes has to be any of: [2, 5]')
+            pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
+            class_names = {val: key for val, key in ['negative', 'positive']}
         df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
-                                    prediction=[class_names[x.item()] for x in y_max.cpu()]))
+                                    prediction=[class_names[x.item()] for x in pred.cpu()]))
         result_file = Path(self.logger.log_dir / 'predictions.csv')
         if result_file.exists():
            try:
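For the binary test path the per-file prediction is the raw mean score, so a small thresholding step can map it onto the two class names before the frame is written to predictions.csv. This is a hedged alternative sketch, not the committed code: the 0.5 threshold and the helper name are assumptions, while the class name lists match the diff.

from pathlib import Path
import pandas as pd
import torch

def predictions_frame(sorted_y: dict, pred: torch.Tensor, n_classes: int) -> pd.DataFrame:
    # Sketch: map per-file predictions to class names for the predictions.csv export.
    if n_classes > 2:
        class_names = dict(enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']))
        labels = [class_names[int(x)] for x in pred.cpu()]          # pred holds argmax class ids
    else:
        class_names = dict(enumerate(['negative', 'positive']))
        labels = [class_names[int(x > 0.5)] for x in pred.cpu()]    # threshold the mean sigmoid score
    return pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
                                  prediction=labels))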