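"""Batch inference over saved checkpoints.

Walks the `output` directory for `*.ckpt` files, restores each model together with its
`config.ini`, runs it over the devel split of `BinaryMasksDataset`, and writes the
predictions as a CSV file into an `outputs` folder next to the checkpoint.
"""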
from pathlib import Path
from pickle import UnpicklingError

import torch
from tqdm import tqdm

import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply

# Transforms
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from ml_lib.visualization.tools import Plotter
from util.config import MConfig

# Datasets
from datasets.binar_masks import BinaryMasksDataset


# Dataset and Dataloaders
# =============================================================================

def prepare_dataloader(config_obj):
    """Build the (unshuffled) devel dataloader with the mel and normalization transforms from the config."""
    mel_transforms = Compose([
        AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
                   hop_length=config_obj.data.hop_length),
        MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])
    """
    aug_transforms = Compose([
        NoiseInjection(0.4),
        LoudnessManipulator(0.4),
        ShiftTime(0.3),
        MaskAug(0.2),
        NormalizeLocal(), ToTensor()
    ])
    """

    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting=V.DATA_OPTIONS.devel,
                                          mel_transforms=mel_transforms, transforms=transforms
                                          )
    # num_workers is forced to 0 here; the configured worker count is deliberately bypassed by the `if False`.
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=config_obj.train.batch_size,
                      num_workers=config_obj.data.worker if False else 0, shuffle=False)


def restore_logger_and_model(log_dir, ckpt):
    """Restore the saved Lightning model from `log_dir` and move it to the GPU if one is available."""
    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, checkpoint=ckpt)
    model = model.restore()
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


if __name__ == '__main__':
    outpath = Path('output')
    config_filename = 'config.ini'

    for checkpoint in outpath.rglob('*.ckpt'):
        inference_out = checkpoint.parent / 'outputs' / f'{checkpoint.name[:-5]}.csv'
        if inference_out.exists():
            continue

        inference_out.parent.mkdir(parents=True, exist_ok=True)

        config = MConfig()
        config.read_file((checkpoint.parent / config_filename).open('r'))

        devel_dataloader = prepare_dataloader(config)

        try:
            loaded_model = restore_logger_and_model(checkpoint.parent, ckpt=checkpoint)
            loaded_model.eval()
        except UnpicklingError:
            # Skip checkpoints that cannot be deserialized.
            continue
        with inference_out.open(mode='w') as outfile:
            # Header matches the rows written below: model prediction, then ground-truth label.
            outfile.write('prediction,label\n')
            for batch in tqdm(devel_dataloader, total=len(devel_dataloader)):
                batch_x, batch_y = batch
                y = loaded_model(batch_x.to(device='cuda' if torch.cuda.is_available() else 'cpu')).main_out
                for prediction, label in zip(y, batch_y):
                    outfile.write(f'{prediction.item()},{label.item()}\n')

    print('Done')