Merge remote-tracking branch 'origin/master'

# Conflicts:
#	multi_run.py
This commit is contained in:
Steffen Illium
2020-05-21 12:16:45 +02:00
6 changed files with 35 additions and 46 deletions

11
main.py
View File

@ -110,6 +110,7 @@ def run_lightning_loop(config_obj):
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
import variables as V
test_dataloader = prepare_dataloader(config_obj)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
@ -118,12 +119,12 @@ def run_lightning_loop(config_obj):
from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_name = batch
batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out
prediction = (y.squeeze() >= 0.5).int().item()
import variables as V
prediction = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction}\n')
predictions = (y >= 0.5).int()
for prediction in predictions:
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction_text}\n')
return model