From 7c8860277623b49d2ce4e683b2831b63ef1de967 Mon Sep 17 00:00:00 2001 From: Steffen Date: Sun, 28 Mar 2021 11:49:17 +0200 Subject: [PATCH] binary test output in working state --- main.py | 10 +++++----- util/module_mixins.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index eb87f10..12a440c 100644 --- a/main.py +++ b/main.py @@ -76,11 +76,11 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad trainer.fit(model, datamodule) trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt') - try: - trainer.test(model=model, datamodule=datamodule) - except: - print('Test did not Suceed!') - pass + + trainer.test(model=model, datamodule=datamodule) + #except: + # print('Test did not Suceed!') + # pass logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1) diff --git a/util/module_mixins.py b/util/module_mixins.py index 851e14b..b89fcd0 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -163,7 +163,8 @@ class TestMixin: enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])} else: pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze() - class_names = {val: key for val, key in ['negative', 'positive']} + pred = torch.where(pred > 0.5, 1, 0) + class_names = {val: key for val, key in enumerate(['negative', 'positive'])} df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],