Audio Dataset
This commit is contained in:
40
main.py
40
main.py
@ -82,47 +82,9 @@ def run_lightning_loop(config_obj):
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
|
||||
model.save_to_disk(logger.log_dir)
|
||||
# trainer.run_evaluation(test_mode=True)
|
||||
|
||||
# Evaluate It
|
||||
if config_obj.main.eval:
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
outputs = []
|
||||
from tqdm import tqdm
|
||||
for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
|
||||
batch_x, label = batch
|
||||
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
label = label.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
outputs.append(
|
||||
model.validation_step((batch_x, label), idx, 1)
|
||||
)
|
||||
model.validation_epoch_end([outputs])
|
||||
|
||||
# trainer.test()
|
||||
outpath = Path(config_obj.train.outpath)
|
||||
model_type = config_obj.model.type
|
||||
parameters = logger.name
|
||||
version = f'version_{logger.version}'
|
||||
inference_out = f'{parameters}_test_out.csv'
|
||||
|
||||
from main_inference import prepare_dataloader
|
||||
import variables as V
|
||||
test_dataloader = prepare_dataloader(config_obj)
|
||||
|
||||
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
||||
outfile.write(f'file_name,prediction\n')
|
||||
|
||||
from tqdm import tqdm
|
||||
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
|
||||
batch_x, file_names = batch
|
||||
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
y = model(batch_x).main_out
|
||||
predictions = (y >= 0.5).int()
|
||||
for prediction, file_name in zip(predictions, file_names):
|
||||
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
|
||||
outfile.write(f'{file_name},{prediction_text}\n')
|
||||
return model
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user