Merge remote-tracking branch 'origin/master'
# Conflicts: # multi_run.py
This commit is contained in:
11
main.py
11
main.py
@ -110,6 +110,7 @@ def run_lightning_loop(config_obj):
|
||||
inference_out = f'{parameters}_test_out.csv'
|
||||
|
||||
from main_inference import prepare_dataloader
|
||||
import variables as V
|
||||
test_dataloader = prepare_dataloader(config_obj)
|
||||
|
||||
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
||||
@ -118,12 +119,12 @@ def run_lightning_loop(config_obj):
|
||||
from tqdm import tqdm
|
||||
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
|
||||
batch_x, file_name = batch
|
||||
batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
|
||||
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
y = model(batch_x).main_out
|
||||
prediction = (y.squeeze() >= 0.5).int().item()
|
||||
import variables as V
|
||||
prediction = 'clear' if prediction == V.CLEAR else 'mask'
|
||||
outfile.write(f'{file_name},{prediction}\n')
|
||||
predictions = (y >= 0.5).int()
|
||||
for prediction in predictions:
|
||||
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
|
||||
outfile.write(f'{file_name},{prediction_text}\n')
|
||||
return model
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user