import csv
from pathlib import Path
import torch
from tqdm import tqdm
import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage

# Transforms
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from ml_lib.visualization.tools import Plotter
from util.config import MConfig

# Datasets
from datasets.binar_masks import BinaryMasksDataset


# Dataset and Dataloaders
# =============================================================================
def prepare_dataloader(config_obj, batch_size=None):
    """Build a non-shuffled DataLoader over the 'test' split of BinaryMasksDataset.

    Audio is converted to a mel spectrogram image (AudioToMel -> MelToImage),
    then normalized (NormalizeLocal) and converted to a tensor (ToTensor).

    Args:
        config_obj: Parsed config; reads ``data.sr``, ``data.n_mels``,
            ``data.n_fft``, ``data.hop_length``, ``data.root``, ``data.worker``
            and ``train.batch_size``.
        batch_size: Optional override for the loader's batch size. ``None``
            (the default) keeps the previous behaviour of using
            ``config_obj.train.batch_size``.

    Returns:
        DataLoader over the test split, in dataset order (``shuffle=False``).
    """
    mel_transforms = Compose([
        AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
                   hop_length=config_obj.data.hop_length),
        MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])

    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
                                          mel_transforms=mel_transforms, transforms=transforms
                                          )
    if batch_size is None:
        batch_size = config_obj.train.batch_size
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=batch_size,
                      num_workers=config_obj.data.worker, shuffle=False)


def restore_logger_and_model(log_dir):
    """Load and restore a saved model from ``log_dir``, moved to GPU if available, else CPU.

    NOTE(review): ``n=-2`` presumably picks the second-to-last checkpoint found
    under ``log_dir`` -- confirm against ``SavedLightningModels.load_checkpoint``.
    """
    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-2)
    model = model.restore()
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


if __name__ == '__main__':
    # Run directory of the trained model; must contain `config.ini` and the checkpoint(s).
    model_path = Path('/home/steffen/projects/inter_challenge_2020/output/CC/CC_fd2020a7ead9d5c80609a7364741f24b/version_40')
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    config = MConfig()
    with (model_path / config_filename).open('r') as config_file:
        config.read_file(config_file)

    # One sample per batch: each model output below is reduced to a single scalar with `.item()`.
    test_dataloader = prepare_dataloader(config, batch_size=1)
    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with (model_path / inference_out).open(mode='w', newline='') as outfile:
        writer = csv.writer(outfile, lineterminator='\n')
        writer.writerow(['file_name', 'prediction'])
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            for batch_x, file_name in tqdm(test_dataloader, total=len(test_dataloader)):
                # Default collation wraps each string in a list (`['x.wav']`);
                # unwrap it so the CSV holds the bare file name.
                if isinstance(file_name, (list, tuple)):
                    file_name = file_name[0]
                # NOTE(review): `unsqueeze(0)` adds a leading dim on top of the batch dim;
                # the model's expected input rank depends on BinaryMasksDataset -- confirm.
                y = loaded_model(batch_x.unsqueeze(0).to(device=device)).main_out
                prediction = (y.squeeze() >= 0.5).int().item()
                prediction = 'clear' if prediction == V.CLEAR else 'mask'
                writer.writerow([file_name, prediction])
    print('Done')