binary test output in working state
This commit is contained in:
8
main.py
8
main.py
@ -76,11 +76,11 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
|
|||||||
trainer.fit(model, datamodule)
|
trainer.fit(model, datamodule)
|
||||||
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
|
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
|
||||||
|
|
||||||
try:
|
|
||||||
trainer.test(model=model, datamodule=datamodule)
|
trainer.test(model=model, datamodule=datamodule)
|
||||||
except:
|
#except:
|
||||||
print('Test did not Suceed!')
|
# print('Test did not Suceed!')
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
|
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
|
||||||
|
|
||||||
|
@ -163,7 +163,8 @@ class TestMixin:
|
|||||||
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||||
else:
|
else:
|
||||||
pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
||||||
class_names = {val: key for val, key in ['negative', 'positive']}
|
pred = torch.where(pred > 0.5, 1, 0)
|
||||||
|
class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
|
||||||
|
|
||||||
|
|
||||||
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
|
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
|
||||||
|
Reference in New Issue
Block a user