"""Batch-inference driver: for every saved checkpoint under ``output/``, run the
restored model over the devel split and write a per-checkpoint CSV of
``prediction,label`` pairs next to the checkpoint."""
from pathlib import Path
from pickle import UnpicklingError

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply

import variables as V
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from ml_lib.visualization.tools import Plotter
from util.config import MConfig

# Datasets
from datasets.binar_masks import BinaryMasksDataset


def prepare_dataloader(config_obj):
    """Build a DataLoader over the devel split of the BinaryMasks dataset.

    Mel-spectrogram extraction parameters come from ``config_obj.data``.
    No augmentation is applied — this loader is used for inference only.

    :param config_obj: parsed MConfig with ``data`` and ``train`` sections.
    :return: a non-shuffling ``DataLoader`` over the devel split.
    """
    mel_transforms = Compose([
        AudioToMel(sr=config_obj.data.sr,
                   n_mels=config_obj.data.n_mels,
                   n_fft=config_obj.data.n_fft,
                   hop_length=config_obj.data.hop_length),
        MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])

    dataset: Dataset = BinaryMasksDataset(config_obj.data.root,
                                          setting=V.DATA_OPTIONS.devel,
                                          mel_transforms=mel_transforms,
                                          transforms=transforms,
                                          )
    # Worker processes are deliberately disabled for inference; the original
    # code pinned this to 0 via a dead `if False` conditional.
    # noinspection PyTypeChecker
    return DataLoader(dataset,
                      batch_size=config_obj.train.batch_size,
                      num_workers=0,
                      shuffle=False)


def restore_logger_and_model(log_dir, ckpt):
    """Restore a saved Lightning model from *ckpt* under *log_dir* and move it
    to the GPU when one is available, otherwise the CPU.

    :raises UnpicklingError: propagated from checkpoint loading when the
        checkpoint file is corrupt or incompatible (handled by the caller).
    """
    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, checkpoint=ckpt)
    model = model.restore()
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


if __name__ == '__main__':
    outpath = Path('output')
    config_filename = 'config.ini'
    # Resolve the target device once, not per batch.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for checkpoint in outpath.rglob('*.ckpt'):
        # '<checkpoint stem>.csv' in an 'outputs' folder beside the checkpoint.
        inference_out = checkpoint.parent / 'outputs' / f'{checkpoint.name[:-5]}.csv'
        if inference_out.exists():
            continue  # this checkpoint was already processed
        inference_out.parent.mkdir(parents=True, exist_ok=True)

        config = MConfig()
        # Use a context manager so the config file handle is closed
        # deterministically (the previous code leaked it).
        with (checkpoint.parent / config_filename).open('r') as config_file:
            config.read_file(config_file)

        devel_dataloader = prepare_dataloader(config)

        try:
            loaded_model = restore_logger_and_model(checkpoint.parent, ckpt=checkpoint)
            loaded_model.eval()
        except UnpicklingError:
            # Best effort: skip corrupt / incompatible checkpoints.
            continue

        with inference_out.open(mode='w') as outfile:
            # BUG FIX: the header previously claimed 'file_name,prediction'
            # although every row actually contains '<prediction>,<label>'.
            outfile.write('prediction,label\n')
            # Inference only — disable autograd bookkeeping.
            with torch.no_grad():
                for batch in tqdm(devel_dataloader, total=len(devel_dataloader)):
                    batch_x, batch_y = batch
                    y = loaded_model(batch_x.to(device=device)).main_out
                    for prediction, label in zip(y, batch_y):
                        outfile.write(f'{prediction.item()},{label.item()}\n')
    print('Done')