85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
import variables as V
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision.transforms import Compose, RandomApply
|
|
|
|
from ml_lib.audio_toolset.audio_augmentation import Speed
|
|
from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, NormalizeLocal, MelToImage
|
|
|
|
# Dataset and Dataloaders
|
|
# =============================================================================
|
|
|
|
# Transforms
|
|
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
|
from ml_lib.utils.config import Config
|
|
from ml_lib.utils.model_io import SavedLightningModels
|
|
from ml_lib.utils.transforms import ToTensor
|
|
|
|
|
|
# Datasets
|
|
from datasets.binar_masks import BinaryMasksDataset
|
|
|
|
|
|
def prepare_dataloader(config_obj):
    """Build the DataLoader used for inference on ``BinaryMasksDataset``.

    The dataset is created with ``setting='test'`` and is iterated in order
    (``shuffle=False``).

    Args:
        config_obj: Configuration object. The fields read here are
            ``data.root``, ``data.sr``, ``data.n_mels``, ``data.n_fft``,
            ``data.hop_length``, ``data.worker`` and ``train.batch_size``.

    Returns:
        A ``DataLoader`` wrapping the dataset.
    """
    # Two transform pipelines are handed to the dataset separately:
    # ``mel_transforms`` and the final ``transforms``.
    mel_transforms = Compose([
        LibrosaAudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
                          hop_length=config_obj.data.hop_length),
        MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])

    # No augmentation transforms are passed to the dataset here.
    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
                                          mel_transforms=mel_transforms, transforms=transforms
                                          )
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=config_obj.train.batch_size,
                      num_workers=config_obj.data.worker, shuffle=False)
|
|
|
|
|
|
def restore_logger_and_model(log_dir):
    """Load a saved model below ``log_dir`` and place it on the available device.

    Args:
        log_dir: Path passed as ``models_root_path`` to
            ``SavedLightningModels.load_checkpoint``.

    Returns:
        The restored model, moved to CUDA when ``torch.cuda.is_available()``
        is true and to the CPU otherwise.
    """
    checkpoint = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-2)
    restored = checkpoint.restore()
    # Pick the bound mover method first, then call it once.
    move_to_device = restored.cuda if torch.cuda.is_available() else restored.cpu
    move_to_device()
    return restored
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual run configuration.
    # NOTE(review): outpath, model_type, parameters and version are not read
    # anywhere below; only model_path selects the run that gets evaluated.
    outpath = Path('output')
    model_type = 'CC'
    parameters = 'CC_213adb16e46592c5a405abfbd693835e/'
    version = 'version_41'
    model_path = Path('/home/steffen/projects/inter_challenge_2020/output/CC/CC_fd2020a7ead9d5c80609a7364741f24b/version_40')
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    config = Config()
    # Read the stored run configuration and close the file handle afterwards.
    with (model_path / config_filename).open() as config_file:
        config.read_file(config_file)
    test_dataloader = prepare_dataloader(config)

    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()

    # The device does not change during the run, so resolve it once.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Predictions are written as CSV next to the model files.
    with (model_path / inference_out).open(mode='w') as outfile:
        outfile.write('file_name,prediction\n')

        # Pure inference: no gradients are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                batch_x, file_name = batch
                # NOTE(review): ``.item()`` below needs a single-element tensor,
                # so this loop appears to expect one sample per batch; confirm
                # against the configured batch size and the extra unsqueeze(0).
                y = loaded_model(batch_x.unsqueeze(0).to(device=device)).main_out
                prediction = (y.squeeze() >= 0.5).int().item()
                prediction = 'clear' if prediction == V.CLEAR else 'mask'
                outfile.write(f'{file_name},{prediction}\n')
    print('Done')
|