bugs fixed, binary datasets working
parent 37e36df0a8
commit 76eb567eed
@@ -37,6 +37,7 @@ patch_size = 8
 attn_depth = 12
 heads = 4
 embedding_size = 128
+mlp_dim = 32
 
 [CNNBaseline]
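Note: the commit adds `mlp_dim` alongside the existing transformer hyperparameters. A minimal sketch of reading such an INI section with the standard library, assuming a file name and section name for illustration (the project may route these values through its own params machinery):

```python
from configparser import ConfigParser

# Hypothetical loader for the section above; file and section names are assumed.
config = ConfigParser()
config.read('config.ini')

section = config['Tester']
attn_depth = section.getint('attn_depth')          # 12
heads = section.getint('heads')                    # 4
embedding_size = section.getint('embedding_size')  # 128
mlp_dim = section.getint('mlp_dim')                # 32, new in this commit
```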
@@ -144,7 +144,9 @@ class CompareBase(_BaseDataModule):
                 print(f'{data_option} skipped...')
                 continue
 
-            if lab_file is not None:
+            if lab_file is None:
+                lab_file = f'{data_option}.csv'
+            elif lab_file is not None:
                 if any([x in lab_file for x in data_options]):
                     lab_file = f'{data_option}.csv'
 
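The fix here: a `None` label file previously fell through unhandled; it now defaults to `{data_option}.csv`. A standalone sketch of the corrected logic (the helper name and the example options are invented):

```python
def resolve_lab_file(lab_file, data_option, data_options):
    # None now gets a per-dataset default instead of falling through.
    if lab_file is None:
        lab_file = f'{data_option}.csv'
    # An explicit label file is only rewritten if it names a known option.
    elif any(x in lab_file for x in data_options):
        lab_file = f'{data_option}.csv'
    return lab_file

print(resolve_lab_file(None, 'primates', ['primates', 'ship']))           # primates.csv
print(resolve_lab_file('primates_v1.csv', 'primates', ['primates', 'ship']))  # primates.csv
print(resolve_lab_file('custom.csv', 'primates', ['primates', 'ship']))       # custom.csv
```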
@@ -8,6 +8,7 @@ from torch import nn
 
 from einops import rearrange, repeat
 
+from ml_lib.metrics.binary_class_classifictaion import BinaryScores
 from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
@@ -64,11 +65,6 @@ class Tester(CombinedModelMixins,
         self.autopad = AutoPadToShape((self.image_size, self.image_size))
 
         # Modules with Parameters
-        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
-                                             heads=self.params.heads, depth=self.params.attn_depth,
-                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
-                                             activation=self.params.activation, use_residual=self.params.use_residual
-                                             )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
         self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
@@ -78,13 +74,17 @@ class Tester(CombinedModelMixins,
 
         self.to_cls_token = nn.Identity()
 
+        logits = self.params.n_classes if self.params.n_classes > 2 else 1
+
+        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
+
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
             nn.GELU(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, n_classes),
-            nn.Softmax()
+            nn.Linear(self.params.lat_dim, logits),
+            outbound_activation()
         )
 
     def forward(self, x, mask=None, return_attn_weights=False):
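With this change the head emits a single sigmoid unit for binary tasks and a softmax over `n_classes` otherwise, which is what makes the binary datasets work end to end. A self-contained sketch of the new head's behavior (hyperparameter values are illustrative):

```python
import torch
from torch import nn

embed_dim, lat_dim, dropout, n_classes = 128, 32, 0.1, 2  # example values

logits = n_classes if n_classes > 2 else 1
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid

mlp_head = nn.Sequential(
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, lat_dim),
    nn.GELU(),
    nn.Dropout(dropout),
    nn.Linear(lat_dim, logits),
    outbound_activation(),  # nn.Softmax would ideally get dim=-1 in the multi-class case
)

out = mlp_head(torch.randn(4, embed_dim))
print(out.shape)  # torch.Size([4, 1]) for the binary case
```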
@@ -106,15 +106,12 @@ class Tester(CombinedModelMixins,
         tensor += self.pos_embedding[:, :(n + 1)]
         tensor = self.dropout(tensor)
-
-        if return_attn_weights:
-            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
-        else:
-            attn_weights = None
-            tensor = self.transformer(tensor, mask)
 
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
-        return Namespace(main_out=tensor, attn_weights=attn_weights)
+        return Namespace(main_out=tensor, attn_weights=None)
 
     def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
+        if self.params.n_classes > 2:
+            return MultiClassScores(self)(outputs)
+        else:
+            return BinaryScores(self)(outputs)
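The score dispatch exists because the two heads now produce differently shaped outputs: binary models emit `(batch, 1)` sigmoid probabilities while multi-class models emit `(batch, n_classes)` softmax rows, so the two cases need different metric computations. A quick shape check (values are random):

```python
import torch

batch = 4
binary_out = torch.sigmoid(torch.randn(batch, 1))         # one probability per sample
multi_out = torch.softmax(torch.randn(batch, 5), dim=-1)  # one distribution per sample

binary_pred = (binary_out.squeeze(-1) > 0.5).long()  # threshold the single unit
multi_pred = multi_out.argmax(dim=-1)                # argmax over classes
print(binary_pred.shape, multi_pred.shape)  # torch.Size([4]) torch.Size([4])
```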
@@ -96,39 +96,27 @@ class ValMixin:
         for file_name in sorted_y:
             sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
 
-        #y_mean = torch.stack(
-        #    [torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
-        #).squeeze()
-
-        #if y_mean.ndim == 1:
-        #    y_mean = y_mean.unsqueeze(0)
-        #if sorted_batch_y.ndim == 1:
-        #    sorted_batch_y = sorted_batch_y.unsqueeze(-1)
-        #
-        #mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
-        #summary_dict.update(val_mean_vote_loss=mean_vote_loss)
-
         if self.params.n_classes <= 2:
-            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
+            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
             max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
+            # Sklearn Scores
+            additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
+
         else:
             y_max = torch.stack(
                 [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
             ).squeeze()
             y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
             max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
-        summary_dict.update(val_max_vote_loss=max_vote_loss)
+            # Sklearn Scores
+            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
+
+        summary_dict.update(val_max_vote_loss=max_vote_loss, **additional_scores)
 
         summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                     for output in outputs]))
                              for key in keys if 'loss' in key}
                             )
-        # Sklearn Scores
-        if self.params.n_classes <= 2:
-            additional_scores = self.additional_scores(dict(y=y_max, batch_y=sorted_batch_y))
-        else:
-            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
-        summary_dict.update(**additional_scores)
 
         pl_metrics, pl_images = self.metrics.compute_and_prepare()
         self.metrics.reset()
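Validation now computes the per-file mean vote inside the binary branch, and the added `.squeeze()` flattens the stacked means to one value per file so `bce_loss` sees shapes matching the targets. A minimal sketch of that aggregation (file names and probabilities are invented; each file has more than one snippet so the stacked means share a shape):

```python
import torch

# Per-file snippet probabilities, as in sorted_y after the per-file torch.stack
sorted_y = {
    'rec_a.wav': torch.tensor([0.9, 0.8, 0.7]),  # 3 snippets
    'rec_b.wav': torch.tensor([0.2, 0.4]),       # 2 snippets
}

mean_sorted_y = torch.stack(
    [x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]
).squeeze()
print(mean_sorted_y)  # tensor([0.8000, 0.3000]), one mean vote per file
```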
@@ -166,19 +154,20 @@ class TestMixin:
         for file_name in sorted_y:
             sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
 
-        y_max = torch.stack(
-            [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
-        ).squeeze().cpu()
-        if self.params.n_classes == 5:
+        if self.params.n_classes > 2:
+            pred = torch.stack(
+                [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
+            ).squeeze().cpu()
             class_names = {val: key for val, key in
                            enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
-        elif self.params.n_classes == 2:
-            class_names = {val: key for val, key in ['negative', 'positive']}
         else:
-            raise AttributeError(f'n_classes has to be any of: [2, 5]')
+            pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
+            class_names = {val: key for val, key in ['negative', 'positive']}
+
+
 
         df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
-                                    prediction=[class_names[x.item()] for x in y_max.cpu()]))
+                                    prediction=[class_names[x.item()] for x in pred.cpu()]))
         result_file = Path(self.logger.log_dir / 'predictions.csv')
         if result_file.exists():
             try:
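The test epoch now routes binary models through a mean-vote branch instead of raising, and writes one row per file to `predictions.csv`. A sketch of the DataFrame construction for the multi-class branch (file names and predictions are invented; note that the binary branch's `class_names` comprehension presumably wants `enumerate(['negative', 'positive'])`, as in the multi-class branch, since unpacking the bare strings into `val, key` would fail at runtime):

```python
from pathlib import Path
import pandas as pd
import torch

# Index-to-name mapping, as built in the multi-class branch of the diff
class_names = {val: key for val, key in
               enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}

sorted_y = {'rec_001.wav': ..., 'rec_002.wav': ...}  # only the keys matter here
pred = torch.tensor([1, 4])                          # per-file argmax votes (invented)

df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
                            prediction=[class_names[x.item()] for x in pred.cpu()]))
print(df)
#       filename prediction
# 0  rec_001.wav  chimpanze
# 1  rec_002.wav     redcap
```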