masks_augments_compare-21/main_inference.py
2020-05-09 21:56:58 +02:00

74 lines
2.6 KiB
Python

from pathlib import Path
import torch
from tqdm import tqdm
import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
# Dataset and Dataloaders
# =============================================================================
# Transforms
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from util.config import MConfig
# Datasets
from datasets.binar_masks import BinaryMasksDataset
def prepare_dataloader(config_obj):
    """Build the DataLoader over the 'test' split of BinaryMasksDataset.

    Audio is first turned into a mel image (AudioToMel -> MelToImage), then
    normalised and converted to a tensor (NormalizeLocal -> ToTensor).

    Args:
        config_obj: parsed run configuration; its ``data`` section provides
            ``root``, ``sr``, ``n_mels``, ``n_fft`` and ``hop_length``.

    Returns:
        A non-shuffling, single-process DataLoader with ``batch_size=None``.
    """
    data_cfg = config_obj.data

    # Stage 1: raw audio -> mel spectrogram -> image.
    audio_to_image = Compose([
        AudioToMel(
            sr=data_cfg.sr,
            n_mels=data_cfg.n_mels,
            n_fft=data_cfg.n_fft,
            hop_length=data_cfg.hop_length,
        ),
        MelToImage(),
    ])
    # Stage 2: image -> locally normalised tensor.
    image_to_tensor = Compose([NormalizeLocal(), ToTensor()])

    test_set: Dataset = BinaryMasksDataset(
        data_cfg.root,
        setting='test',
        mel_transforms=audio_to_image,
        transforms=image_to_tensor,
    )

    # batch_size=None turns off automatic batching in PyTorch: every iteration
    # yields one dataset item unchanged (callers add the batch axis themselves).
    # noinspection PyTypeChecker
    return DataLoader(test_set, batch_size=None, num_workers=0, shuffle=False)
def restore_logger_and_model(config_obj):
    """Restore a trained model from the checkpoints in the run's log directory.

    A ``Logger`` is constructed from ``config_obj`` solely to obtain its
    ``log_dir``; despite the function name, only the model is returned.

    Args:
        config_obj: parsed run configuration used to build the Logger.

    Returns:
        The restored model, moved to CUDA when available, otherwise to CPU.
    """
    run_logger = Logger(config_obj)
    # NOTE(review): n=-2 picks which saved checkpoint to load; its exact meaning
    # (presumably an index into the checkpoint list) is defined by
    # SavedLightningModels.load_checkpoint — TODO confirm.
    saved = SavedLightningModels.load_checkpoint(models_root_path=run_logger.log_dir, n=-2)
    restored = saved.restore()

    # The module is moved in place; the return value of cpu()/cuda() is not needed.
    if not torch.cuda.is_available():
        restored.cpu()
    else:
        restored.cuda()
    return restored
if __name__ == '__main__':
    # Script entry point: run a trained classifier over the test split and
    # write one "file_name,prediction" CSV row per sample.
    import csv  # needed only by this entry point

    outpath = Path('output')
    model_type = 'BandwiseConvMultiheadClassifier'
    parameters = 'BCMC_9c70168a5711c269b33701f1650adfb9/'
    version = 'version_1'
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    # The config that is read and the predictions that are written both live
    # in the same run/version directory.
    run_dir = outpath / model_type / parameters / version

    config = MConfig()
    # Close the config file once it has been parsed (previously the handle
    # returned by .open() was never closed).
    with (run_dir / config_filename).open('r') as config_file:
        config.read_file(config_file)

    test_dataloader = prepare_dataloader(config)
    loaded_model = restore_logger_and_model(config)
    loaded_model.eval()
    # Hoisted out of the loop; mirrors the cuda/cpu choice made when restoring the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # newline='' is what the csv module requires; lineterminator='\n' keeps
    # the previous Unix line endings. csv.writer quotes file names that
    # contain commas or quotes instead of producing a malformed row.
    with (run_dir / inference_out).open(mode='w', newline='') as outfile:
        writer = csv.writer(outfile, lineterminator='\n')
        writer.writerow(['file_name', 'prediction'])
        # Pure inference: disable autograd so no graphs are kept per sample.
        with torch.no_grad():
            for batch_x, file_name in tqdm(test_dataloader, total=len(test_dataloader)):
                # The loader yields un-batched samples; add a leading batch axis of size 1.
                y = loaded_model(batch_x.unsqueeze(0).to(device=device)).main_out
                # NOTE(review): assumes main_out is a single score in [0, 1]
                # thresholded at 0.5 — TODO confirm.
                class_idx = (y.squeeze() >= 0.5).int().item()
                label = 'clear' if class_idx == V.CLEAR else 'mask'
                writer.writerow([file_name, label])
    print('Done')