import pickle
from pathlib import Path
from typing import Union
from warnings import warn

from collections import defaultdict

import os
from torch.utils.data import Dataset
from tqdm import tqdm

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data

from utils.project_settings import Classes, DataSplit, ClusterTypes


def save_names(name_list, path):
    # Write one name per line as plain text.
    with open(path, 'w') as f:
        f.write('\n'.join(name_list))


class CustomShapeNet(InMemoryDataset):
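    """
    In-memory point-cloud dataset built on torch_geometric's InMemoryDataset.
    Reads *.xyz files from <root>/raw/<mode>, maps their labels through one of
    the class maps below and caches the collated result in <root>/processed/<mode>.pt.
    """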

    name = 'CustomShapeNet'

    def download(self):
        pass

    @property
    def categories(self):
        return {key: val for val, key in self.classes.items()}

    @property
    def modes(self):
        return {key: val for val, key in DataSplit().items()}

    @property
    def cluster_types(self):
        return {key: val for val, key in ClusterTypes().items()}

    @property
    def raw_dir(self):
        return self.root / 'raw'

    @property
    def raw_file_names(self):
        return [self.mode]

    @property
    def processed_dir(self):
        return self.root / 'processed'

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, refresh=False, cluster_type: Union[str, None] = '',
                 poly_as_plane=False):
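        """
        root_dir: dataset root that contains the 'raw/<mode>' folders.
        collate_per_segment: if True, every segment becomes its own sample; otherwise a whole cloud is one sample.
        refresh: if True, delete the cached processed files and rebuild them.
        cluster_type: substring used to select raw files by name; falls back to 'pc' when empty.
        poly_as_plane: if True, use class_map_poly_as_plane instead of class_map_all.
        """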
        assert mode in self.modes.keys(), \
            f'"mode" must be one of {self.modes.keys()}'
        assert cluster_type in self.cluster_types.keys() or cluster_type is None, \
            f'"cluster_type" must be one of {self.cluster_types.keys()} or None, but was: {cluster_type}'

        # Set the Dataset Parameters
        self.cluster_type = cluster_type if cluster_type else 'pc'
        self.classes = Classes()
        self.poly_as_plane = poly_as_plane
        self.collate_per_segment = collate_per_segment
        self.mode = mode
        self.refresh = refresh

        root_dir = Path(root_dir)
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def processed_file_names(self):
        return [f'{self.mode}.pt']

    def check_and_resolve_cloud_count(self):
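        """Return the raw file count for this mode, 0 if the folder is empty, or -1 if it does not exist."""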
        if (self.raw_dir / self.mode).exists():
            file_count = len([cloud for cloud in (self.raw_dir / self.mode).iterdir() if cloud.is_file()])

            if file_count:
                print(f'{file_count} files have been found....')
                return file_count
            else:
                warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
                return file_count
        warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
        return -1

    @property
    def num_classes(self):
        return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)

    @property
    def class_map_all(self):
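        """Map the 8 raw labels onto 5 training classes; labels mapped to None are dropped."""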
        return {0: 0,
                1: 1,
                2: None,
                3: 2,
                4: 3,
                5: None,
                6: 4,
                7: None
                }

    @property
    def class_map_poly_as_plane(self):
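        """Same mapping, but raw labels 4 and 6 share the target class of label 3 (everything treated as a plane)."""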
        return {0: 0,
                1: 1,
                2: None,
                3: 2,
                4: 2,
                5: None,
                6: 2,
                7: None
                }

    def _load_dataset(self):
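        """
        Load the cached, collated dataset from disk; if it is missing or was
        built with different parameters, rebuild it via process().
        """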
        data, slices = None, None
        filepath = self.processed_paths[0]
        config_path = Path(filepath).parent / f'{self.mode}_params.ini'
        if config_path.exists() and not self.refresh and self.mode != DataSplit().predict:
            with config_path.open('rb') as f:
                config = pickle.load(f)
            if config != self._build_config():
                print('The given data parameters seem to differ from those used to process the dataset; refreshing.')
                self.refresh = True
        if self.refresh:
            try:
                os.remove(filepath)
                try:
                    config_path.unlink()
                except FileNotFoundError:
                    pass
                print('Processed location "refreshed" (existing files were deleted).')
            except FileNotFoundError:
                print('The already processed dataset was meant to be refreshed, but none was found...')
                print('Continuing with processing.')

        while True:
            try:
                data, slices = torch.load(filepath)
                print(f'{self.mode.title()}-Dataset Loaded')
                break
            except FileNotFoundError:
                status = self.check_and_resolve_cloud_count()
                if status in [0, -1]:
                    print(f'No dataset was loaded, status: {status}')
                    break
                self.process()
                continue
        if self.mode != DataSplit().predict:
            config = self._build_config()
            with config_path.open('wb') as f:
                pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
        return data, slices

    def _build_config(self):
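        """Snapshot of the constructor parameters, used to detect whether the cached dataset was built with different settings."""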
        conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
            'classes', 'refresh', 'transform', 'data', 'slices'
        ]}
        return conf_dict

    def _pre_transform_and_filter(self, data):
        # Drop samples rejected by the pre_filter, then apply the pre_transform.
        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data

    def process(self, delimiter=' '):
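        """
        Parse the raw *.xyz clouds of the current mode. Each row is expected to hold at least
        'x y z nx ny nz', optionally followed by a class label and a cluster id (missing values
        are padded with zeros); the resulting samples are collated and saved as one .pt file.
        """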
        datasets = defaultdict(list)
        path_to_clouds = self.raw_dir / self.mode
        found_clouds = list(path_to_clouds.glob('*.xyz'))
        class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
        if len(found_clouds):
            for pointcloud in tqdm(found_clouds):
                if self.cluster_type not in pointcloud.name:
                    continue
                data = None

                with pointcloud.open('r') as f:
                    src = defaultdict(list)
                    # Iterate over all rows; skip blank lines.
                    for row in f:
                        if not row.strip():
                            continue
                        vals = row.rstrip().split(delimiter)
                        vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
                        if len(vals) < 6:
                            raise ValueError(f'Expected at least 6 values per row, got {len(vals)}: {row!r}')
                        # Pad the values with fake labels if none are provided.
                        vals = vals + [0] * (8 - len(vals))
                        mapped_class = class_map[int(vals[-2])]
                        if mapped_class is None:
                            # Labels mapped to None are not part of the reduced class set; skip them.
                            continue
                        vals[-2] = float(mapped_class)
                        src[vals[-1]].append(vals)

                # Convert the defaultdict back to a plain dict
                src = dict(src)

                # Transform the Dict[List] to Dict[torch.Tensor]
                for key, values in src.items():
                    src[key] = torch.tensor(values, dtype=torch.double).squeeze()

                # Ignore the per-segment split and merge everything into a single full cloud
                if not self.collate_per_segment:
                    src = dict(
                        all=torch.cat(tuple(src.values()))
                    )

                # Make sure every segment tensor is 2-D; all-zero 1-D placeholders stay 1-D and are skipped below
                for key, tensor in src.items():
                    if tensor.ndim == 1:
                        if all([x == 0 for x in tensor]):
                            continue
                        tensor = tensor.unsqueeze(0)
                    src[key] = tensor

                for key, values in src.items():
                    try:
                        points = values[:, :-2]
                    except IndexError:
                        continue
                    y = torch.as_tensor(values[:, -2], dtype=torch.long)
                    y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
                    ####################################
                    # This is where you define the keys
                    attr_dict = dict(
                        y=y,
                        y_c=y_c,
                        pos=points[:, :3],
                        norm=points[:, 3:6]
                    )

                    ####################################
                    if self.collate_per_segment:
                        data = Data(**attr_dict)
                    else:
                        if data is None:
                            data = defaultdict(list)
                        for attr_key, val in attr_dict.items():
                            data[attr_key].append(val)

                    # data = self._pre_transform_and_filter(data)
                    if self.collate_per_segment:
                        datasets[self.mode].append(data)
                if not self.collate_per_segment:
                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

        if datasets[self.mode]:
            os.makedirs(self.processed_dir, exist_ok=True)
            collated_dataset = self.collate(datasets[self.mode])
            torch.save(collated_dataset, self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
    """

    name = 'ShapeNetPartSegDataset'

    def __init__(self, root_dir, mode='train', **kwargs):
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        # self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)

    def __getitem__(self, index):
        data = self.dataset[index]

        # Resampling to a fixed number of points (currently disabled):
        '''
        try:
            npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0]
            choice = np.random.choice(data.pos.shape[0], npoints,
                                      replace=False if self.mode == DataSplit.predict else True
                                      )
        except ValueError:
            choice = []

        pos, norm, y = data.pos[choice, :], data.norm[choice], data.y[choice]   

        # y -= 1 if self.num_classes() in y else 0   # Map label from [1, C] to [0, C-1]

        data = Data(**dict(pos=pos,          # torch.Tensor (n, 3/6)
                           y=y,              # torch.Tensor (n,)
                           norm=norm         # torch.Tensor (n, 3/0)
                           )
                    )
        '''
        return data

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
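

if __name__ == '__main__':
    # Minimal usage sketch. './data' is a placeholder path and assumes *.xyz
    # clouds under './data/raw/train'; the keyword values are example choices,
    # not defaults shipped with this module.
    dataset = ShapeNetPartSegDataset(root_dir='./data', mode='train',
                                     collate_per_segment=True, poly_as_plane=False)
    print(f'{len(dataset)} samples, {dataset.num_classes()} classes')
    sample = dataset[0]
    print(sample.pos.shape, sample.norm.shape, sample.y.shape)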