import shelve
from pathlib import Path
from typing import Union, List

import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
import numpy as np

from lib.objects.map import Map
import lib.variables as V
from PIL import Image


class TrajDataset(Dataset):
    """Dataset of random trajectories drawn on a single map image.

    Each sample pairs a random ground-truth trajectory with a generated
    alternative; the label is whether the two are homotopic (as decided by
    ``Map.are_homotopic``). The 'just_route' mode instead returns the map
    plus a single rasterized trajectory with a random dummy label.
    """

    @property
    def map_shape(self):
        # Shape of the rasterized map (array form of the loaded image).
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
        """
        :param maps_root: directory containing the map bitmap files.
        :param mapname: map file name; '.bmp' is appended when missing.
        :param normalized: in 'separated_arrays' mode, divide arrays by V.WHITE.
        :param length: virtual dataset length reported by __len__.
        :param mode: one of 'vectors', 'all_in_map', 'separated_arrays', 'just_route'
            (case-insensitive).
        :param embedding_size: target size forwarded to Map.from_image.
        :param preserve_equal_samples: if True, resample until the label differs
            from the previous sample's label, so consecutive labels alternate.
        """
        super(TrajDataset, self).__init__()
        # Normalize the mode once up front. The original lowered it in some
        # comparisons but not others (e.g. `self.mode == 'separated_arrays'`),
        # so mixed-case input passed the assert yet hit the wrong branch.
        mode = mode.lower()
        assert mode in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
        self.normalized = normalized
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Coerce to Path: the default '' is a str, and `str / str` would raise
        # TypeError when building the image path below.
        self.maps_root = Path(maps_root)
        self._len = length
        # Sentinel distinct from any real label so the first sample is kept.
        self.last_label = -1

        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        # Virtual length: samples are generated on the fly, so any length works.
        return self._len

    def __getitem__(self, item):

        if self.mode == 'just_route':
            trajectory = self.map.get_random_trajectory()
            # Use the map_shape property for consistency (the original used
            # `self.map.shape` here but `self.map_shape` everywhere else).
            trajectory_space = trajectory.draw_in_array(self.map_shape)
            # Dummy random label — no homotopy comparison in this mode.
            label = choice([0, 1])
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, trajectory_space), label

        # With preserve_equal_samples, resample until the label flips so the
        # stream of labels stays balanced.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if self.preserve_equal_samples and label == self.last_label:
                continue
            break

        self.last_label = label
        if self.mode in ['all_in_map', 'separated_arrays']:
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
            if self.mode == 'separated_arrays':
                if self.normalized:
                    # Scale pixel values to [0, 1] using the white level.
                    map_array = map_array / V.WHITE
                    trajectory = trajectory / V.WHITE
                    alternative = alternative / V.WHITE
                return (map_array, trajectory, label), alternative
            # 'all_in_map': stack map, trajectory and alternative as channels.
            return np.concatenate((map_array, trajectory, alternative)), label

        if self.mode == 'vectors':
            return trajectory.vertices, alternative.vertices, label, self.mapname

        # Unreachable after the __init__ assert, kept as a guard.
        raise ValueError(f'Unknown mode: {self.mode}')


class TrajData(object):
    """Bundles one TrajDataset per map bitmap found under ``map_root`` into a
    single ConcatDataset, which is shared by train/val/test splits."""

    @property
    def map_shapes(self):
        # Shapes of all per-map datasets, in ConcatDataset order.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Element-wise maximum over all map shapes; array modes get two extra
        # leading channels (trajectory + alternative drawn alongside the map).
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if self.mode in ['separated_arrays', 'all_in_map']:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
        """
        :param map_root: directory scanned (non-recursively) for '*.bmp' maps.
        :param length: total virtual length, split evenly across the maps.
        :param mode: sample mode forwarded to each TrajDataset.
        :param normalized: forwarded to each TrajDataset.
        :raises FileNotFoundError: if no '*.bmp' file is found under map_root.
        """
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        map_files = list(self.maps_root.glob('*.bmp'))
        if not map_files:
            # Fail early with a clear message instead of the opaque
            # ZeroDivisionError the length split would raise below.
            raise FileNotFoundError(f'No *.bmp map files found in {self.maps_root}')
        # Give every map an equal share of the virtual length, at least 1.
        equal_split = int(self.length // len(map_files)) or 1

        # Find the maximum image size among the available maps. Open each image
        # in a context manager so file handles are released promptly, and take
        # the element-wise max via zip — the previous `map(max, *sizes)` form
        # raised TypeError when exactly one map file was present.
        sizes = []
        for map_file in map_files:
            with Image.open(map_file) as img:
                sizes.append(img.size)
        # PIL reports (width, height); reverse to (height, width) and prepend
        # a channel dimension.
        max_map_size = (1, ) + tuple(reversed(tuple(max(dim) for dim in zip(*sizes))))
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    @property
    def train_dataset(self):
        return self._dataset

    @property
    def val_dataset(self):
        # Note: validation shares the same dataset object as training.
        return self._dataset

    @property
    def test_dataset(self):
        # Note: test shares the same dataset object as training.
        return self._dataset

    def get_datasets(self):
        # (train, val, test) — all three are the same ConcatDataset.
        return self._dataset, self._dataset, self._dataset