from pathlib import Path

import torch
from tqdm import tqdm

import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage

# Dataset and Dataloaders
# =============================================================================

# Transforms
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from util.config import MConfig

# Datasets
from datasets.binar_masks import BinaryMasksDataset


def prepare_dataloader(config_obj):
    """Build the DataLoader that yields preprocessed mel samples for inference.

    Audio is converted to a mel representation and then to an image
    (``AudioToMel`` -> ``MelToImage``); the result is normalised and turned into
    a tensor (``NormalizeLocal`` -> ``ToTensor``). No augmentation is applied.

    :param config_obj: configuration object; reads ``data.sr``, ``data.n_mels``,
        ``data.n_fft``, ``data.hop_length`` and ``data.root``.
    :return: a ``DataLoader`` over ``BinaryMasksDataset`` with automatic batching
        disabled (``batch_size=None``), no worker processes and no shuffling.
    """
    data_cfg = config_obj.data

    # Audio to Mel Transformations
    mel_transforms = Compose([
        AudioToMel(sr=data_cfg.sr, n_mels=data_cfg.n_mels, n_fft=data_cfg.n_fft,
                   hop_length=data_cfg.hop_length),
        MelToImage(),
    ])
    # Utility
    transforms = Compose([NormalizeLocal(), ToTensor()])

    # NOTE(review): setting='train' is used although the caller treats this as the
    # test dataloader -- confirm this is the intended split.
    dataset: Dataset = BinaryMasksDataset(data_cfg.root, setting='train',
                                          mel_transforms=mel_transforms, transforms=transforms)
    # batch_size=None disables automatic batching: samples are passed through one by one.
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)


def restore_logger_and_model(config_obj):
    """Load a saved model from the run's log directory and move it to the available device.

    A ``Logger`` is created from ``config_obj`` only to obtain its ``log_dir``;
    despite the function name, only the model is returned.

    :param config_obj: configuration object handed to ``Logger``.
    :return: the restored model, moved to CUDA when available and to the CPU otherwise.
    """
    run_logger = Logger(config_obj)
    checkpoint = SavedLightningModels.load_checkpoint(models_root_path=run_logger.log_dir, n=-2)
    restored = checkpoint.restore()

    # Select the bound mover first, then invoke it: GPU if present, CPU otherwise.
    place_on_device = restored.cuda if torch.cuda.is_available() else restored.cpu
    place_on_device()
    return restored


if __name__ == '__main__':
    # Manual inference run: restore a trained model, predict every sample of the
    # dataloader and write one "file_name,prediction" row per sample to a CSV file.
    outpath = Path('output')
    model_type = 'CC'
    parameters = 'CC_213adb16e46592c5a405abfbd693835e/'
    version = 'version_41'
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    # Directory of the run under evaluation; holds the config and receives the results.
    run_dir = outpath / model_type / parameters / version

    config = MConfig()
    # Read the config inside a context manager so the file handle is closed again.
    with (run_dir / config_filename).open('r') as config_file:
        config.read_file(config_file)
    test_dataloader = prepare_dataloader(config)
    loaded_model = restore_logger_and_model(config)
    loaded_model.eval()

    # Same device choice as restore_logger_and_model; computed once, not per sample.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Inference only: no_grad avoids building an autograd graph for every forward pass.
    with (run_dir / inference_out).open(mode='w') as outfile, torch.no_grad():
        outfile.write('file_name,prediction\n')

        for batch_x, file_name in tqdm(test_dataloader, total=len(test_dataloader)):
            # The dataloader yields unbatched samples, so the batch dimension is added here.
            y = loaded_model(batch_x.unsqueeze(0).to(device=device)).main_out
            # Threshold the squeezed output at 0.5 to obtain the integer class label.
            prediction = (y.squeeze() >= 0.5).int().item()
            prediction = 'clear' if prediction == V.CLEAR else 'mask'
            outfile.write(f'{file_name},{prediction}\n')
    print('Done')