import multiprocessing as mp
import shelve
from collections import defaultdict
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision.transforms import Normalize
from tqdm import tqdm

import lib.variables as V
from lib.objects.map import Map
from lib.utils.tools import write_to_shelve


class TrajDataShelve(Dataset):
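    """Read-only dataset backed by a `shelve` file of pre-generated samples.

    Samples are stored under their stringified integer index; a lock guards
    each file access so that parallel readers do not corrupt the shelve.
    """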

    @property
    def map_shape(self):
        return self[0][0].shape

    def __init__(self, file_path, **kwargs):
        assert Path(file_path).exists(), f'No shelve file found at "{file_path}"'
        super().__init__()
        self._mutex = mp.Lock()
        self.file_path = str(file_path)

    def __len__(self):
        with self._mutex:
            # shelve.open() is a context manager; it closes the file on exit.
            with shelve.open(self.file_path) as d:
                length = len(d)
        return length

    def seed(self, seed):
        pass  # Samples are pre-generated; there is nothing to seed.

    def __getitem__(self, item):
        with self._mutex:
            with shelve.open(self.file_path) as d:
                sample = d[str(item)]
        return sample


class TrajDataset(Dataset):
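    """On-the-fly dataset of trajectory pairs drawn on a single map.

    Each sample pairs a random trajectory with a generated alternative plus a
    label telling whether the two are homotopic; `mode` controls how these
    parts are arranged into model inputs and targets.
    """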

    @property
    def _last_label_init(self):
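        # For the single-class generator modes, pin the initial label to the
        # unwanted class so that __getitem__ (with preserve_equal_samples set)
        # skips samples of that class.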
        d = defaultdict(lambda: -1)
        d['generator_hom_all_in_map'] = V.ALTERNATIVE
        d['generator_alt_all_in_map'] = V.HOMOTOPIC
        return d[self.mode]

    @property
    def map_shape(self):
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
        super().__init__()
        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
                                'ae_no_label_in_map', 'generator_alt_no_label_in_map',
                                'generator_hom_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        self.maps_root = Path(maps_root)
        self._len = length
        self.last_label = self._last_label_init

        # Load the map bitmap, embedded into `embedding_size` if one is given.
        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        return self._len

    def __getitem__(self, item):

        if self.mode.lower() == 'just_route':
            raise NotImplementedError
            # Unreachable sketch of the intended 'just_route' sample layout:
            # trajectory = self.map.get_random_trajectory()
            # trajectory_space = trajectory.draw_in_array(self.map.shape)
            # label = choice([0, 1])
            # map_array = torch.as_tensor(self.map.as_array).float()
            # return (map_array, trajectory_space), label

        # Draw a random trajectory and generate an alternative for it. With
        # preserve_equal_samples, skip samples whose label repeats the previous
        # one so that the classes stay balanced.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if self.preserve_equal_samples and label == self.last_label:
                continue
            break

        self.last_label = label if self._last_label_init == V.ANY else self._last_label_init
        if 'in_map' in self.mode.lower():
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
            label_as_array = np.full_like(map_array, label)

            if self.mode in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map']:
                return np.concatenate((map_array, trajectory, label_as_array)), alternative
            elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
                return np.sum((map_array, trajectory, alternative), axis=0), 0
            elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
                return np.concatenate((map_array, trajectory)), alternative
            elif self.mode == 'classifier_all_in_map':
                return np.concatenate((map_array, trajectory, alternative)), label

        elif self.mode == '_vectors':
            # Not implemented; the intended return was:
            #   trajectory.vertices, alternative.vertices, label, self.mapname
            raise NotImplementedError

        raise ValueError(f'Mode was: {self.mode}')

    def seed(self, seed):
        self.map.seed(seed)


class TrajData(object):
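    """Bundles train/val/test datasets over all '*.bmp' maps in a root folder.

    With `preprocessed=True`, samples are generated once, dumped into '*.pik'
    shelve files and then served from disk via TrajDataShelve; otherwise they
    are generated on the fly by TrajDataset.
    """
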
    @property
    def map_shapes(self):
        return [dataset.map_shape for dataset in self.train_dataset.datasets]

    @property
    def map_shapes_max(self):
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if '_all_in_map' in self.mode and not self.preprocessed:
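            # The '_all_in_map' modes stack two extra planes (trajectory plus
            # label or alternative) onto the single map channel.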
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
        self.preprocessed = preprocessed
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self.test_dataset = self._load_datasets('test')
        self.val_dataset = self._load_datasets('val')
        self.train_dataset = self._load_datasets('train')

    def _load_datasets(self, dataset_type=''):

        map_files = list(self.maps_root.glob('*.bmp'))

        # Find the maximum image size among the available maps; PIL reports
        # (width, height), whereas arrays are laid out as (channels, height, width).
        sizes = [Image.open(map_file).size for map_file in map_files]
        max_map_size = (1,) + tuple(reversed(tuple(max(dim) for dim in zip(*sizes))))

        if self.preprocessed:
            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
            preprocessed_map_names = [p.name for p in preprocessed_map_files]
            datasets = []
            for map_file in map_files:
                # Split the requested length evenly across maps; val/test sets
                # get 1% of the training share, but at least 10 samples.
                equal_split = int(self.length // len(map_files)) or 5
                new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
                if dataset_type != 'train':
                    equal_split = max(int(equal_split * 0.01), 10)
                if new_pik_name not in preprocessed_map_names:
                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                               mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                               preserve_equal_samples=True)
                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)

                dataset = TrajDataShelve(map_file.parent / new_pik_name)
                datasets.append(dataset)
            return ConcatDataset(datasets)

        # Set the equal split so that all maps are visited with the same frequency
        equal_split = int(self.length // len(map_files)) or 5
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    def kill_em_all(self):
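        """Delete all preprocessed '*.pik' shelve caches under `maps_root`."""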
        for pik_file in self.maps_root.glob('*.pik'):
            pik_file.unlink()
            print(f'{pik_file.name} was deleted.')
        print('Done.')

    def seed(self, seed):
        for concat_dataset in [self.train_dataset, self.val_dataset, self.test_dataset]:
            for dataset in concat_dataset.datasets:
                dataset.seed(seed)

    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
        assert str(file_path).endswith('.pik')
        # Samples are generated and written sequentially within this process,
        # so no lock is needed here.
        for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
            sample = traj_dataset[i]
            write_to_shelve(file_path, sample)

        print(f'{n} samples successfully dumped to "{file_path}"!')

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
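

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library proper. The maps directory
    # ('res/maps') is a hypothetical path; point it at a folder of '*.bmp' maps.
    from torch.utils.data import DataLoader

    traj_data = TrajData(map_root='res/maps', length=1000, mode='classifier_all_in_map')
    train_loader = DataLoader(traj_data.train_dataset, batch_size=32, shuffle=True)

    sample_batch, label_batch = next(iter(train_loader))
    print(f'Batch shape: {tuple(sample_batch.shape)}, labels shape: {tuple(label_batch.shape)}')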