import multiprocessing as mp
from collections import defaultdict
from itertools import repeat
from pathlib import Path

from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]


class PrimatesLibrosaDatamodule(_BaseDataModule):
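    """Datamodule for the primates audio dataset: builds librosa mel-spectrogram datasets from the wav files
    listed in lab/{train,devel,test}.csv and provides the matching train/val/test dataloaders."""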

    @property
    def class_names(self):
        return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}

    @property
    def n_classes(self):
        return len(self.class_names)

    @property
    def shape(self):
        # Shape of a single mel sample, taken from the first sub-dataset of the training split.
        return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape

    @property
    def mel_folder(self):
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
                 target_mel_length_in_seconds=0.7, random_apply_chance=0.5,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
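        """
        :param data_root: root directory that contains the 'primates' dataset folder.
        :param batch_size: batch size used by all dataloaders.
        :param num_worker: number of dataloader / pool workers (falls back to 1 if falsy).
        :param sr: sample rate for the mel transform.
        :param n_mels: number of mel bins.
        :param n_fft: FFT window size.
        :param hop_length: hop length between STFT frames.
        :param sampler: class name of the train sampler ('EqualSampler' or 'WeightedRandomSampler'), or None.
        :param target_mel_length_in_seconds: length of one mel segment in seconds.
        :param random_apply_chance: probability with which each augmentation is applied.
        :param loudness_ratio: ratio passed to LoudnessManipulator.
        :param shift_ratio: ratio passed to ShiftTime.
        :param noise_ratio: ratio passed to NoiseInjection.
        :param mask_ratio: ratio passed to MaskAug.
        """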
        super().__init__()
        self.sampler = sampler
        self.samplers = None

        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / 'primates'
        self.target_mel_length_in_seconds = target_mel_length_in_seconds

        # Mel transform parameters - passed together with all other attributes (via a copy of self.__dict__) to the sub-dataset class
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Utility
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data Augmentations
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False):
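        """Build a LibrosaAudioToMelDataset for a single csv row of the form '<file_name>,<class_name>'."""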
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # Data option differentiation: devel/test files get the plain utility transforms instead of augmentations.
        # Copy self.__dict__ so the per-file overrides below do not mutate the datamodule itself.
        kwargs = dict(self.__dict__)
        if any(x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]):
            kwargs.update(mel_augmentations=self.utility_transforms)

        # Convert the target segment length from seconds to audio samples, then to mel frames.
        target_samples = self.target_mel_length_in_seconds * self.mel_kwargs['sr']
        sample_segment_length = int(target_samples // self.mel_kwargs['hop_length']) + 1
        kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset, class_id, slice_file_name

    def prepare_data(self, *args, **kwargs):
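        """Pre-build the mel spectrograms for every file listed in lab/<split>.csv, using a multiprocessing pool."""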
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            chunksize = max(len(all_rows) // max(self.num_worker, 1), 1)
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
                                             chunksize=chunksize)
                for sub_dataset in results.get():
                    dataset.append(sub_dataset[0])
            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets

    def setup(self, stage=None):
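        """Build the per-split ConcatDatasets and, for the train split, the configured sampler."""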
        datasets = dict()
        samplers = dict()
        weights = dict()

        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            dataset = list()
            for row in all_rows:
                mel_dataset, class_id, _ = self._build_subdataset(row)
                dataset.append(mel_dataset)
            datasets[data_option] = ConcatDataset(dataset)

            # Build a sampler for the training split
            if data_option == DATA_OPTION_train:
                if self.sampler == EqualSampler.__name__:
                    class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
                                  for class_idx in range(len(self.class_names))
                                  ]
                    samplers[data_option] = EqualSampler(class_idxs)
                elif self.sampler == WeightedRandomSampler.__name__:
                    class_counts = defaultdict(int)
                    for _, __, label in datasets[data_option]:
                        class_counts[label] += 1
                    len_largest_class = max(class_counts.values())

                    # Inverse class-frequency weights, expanded below to one weight per sample
                    weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
                    weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
                                            for i in range(len(datasets[data_option]))]
                    samplers[data_option] = WeightedRandomSampler(weights[data_option],
                                                                  len_largest_class * len(self.class_names))
                else:
                    samplers[data_option] = None
        self.datasets = datasets
        self.samplers = samplers
        return datasets

    def purge(self):
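        """Recursively delete the cached mel folder and report whether the deletion succeeded."""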
        import shutil

        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel Folder has been recursively deleted')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()