Audio Dataset

This commit is contained in:
Si11ium
2020-12-01 16:37:16 +01:00
parent 95561acc35
commit 95dcf22f3d
15 changed files with 468 additions and 145 deletions

40
main.py
View File

@ -82,47 +82,9 @@ def run_lightning_loop(config_obj):
# Save the last state & all parameters
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
model.save_to_disk(logger.log_dir)
# trainer.run_evaluation(test_mode=True)
# Evaluate It
if config_obj.main.eval:
with torch.no_grad():
model.eval()
if torch.cuda.is_available():
model.cuda()
outputs = []
from tqdm import tqdm
for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
batch_x, label = batch
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
label = label.to(device='cuda' if model.on_gpu else 'cpu')
outputs.append(
model.validation_step((batch_x, label), idx, 1)
)
model.validation_epoch_end([outputs])
# trainer.test()
outpath = Path(config_obj.train.outpath)
model_type = config_obj.model.type
parameters = logger.name
version = f'version_{logger.version}'
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
import variables as V
test_dataloader = prepare_dataloader(config_obj)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_names = batch
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out
predictions = (y >= 0.5).int()
for prediction, file_name in zip(predictions, file_names):
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction_text}\n')
return model