Merge remote-tracking branch 'origin/master'

# Conflicts:
#	main.py
This commit is contained in:
Steffen Illium
2021-03-18 21:44:35 +01:00
4 changed files with 24 additions and 19 deletions

View File

@@ -147,9 +147,8 @@ class TestMixin:
).squeeze().cpu()
class_names = {val: key for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
df = pd.DataFrame(data=dict(filenames=[Path(x).stem for x in sorted_y.keys()],
prediction=y_max.cpu().numpy(),
prediction_named=[class_names[x.item()] for x in y_max.cpu().numpy()]))
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
prediction=y_max.cpu().numpy()))
result_file = Path(self.logger.log_dir / 'predictions.csv')
if result_file.exists():
try: