import glob
import os
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm


class CustomShapeNet(InMemoryDataset):

    # Map category names to integer labels: {'Box': 0, 'Cone': 1, 'Cylinder': 2, 'Sphere': 3}
    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}

    def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
                 headers=True, **kwargs):
        self.has_headers = headers
        self.collate_per_segment = collate_per_segment
        self.train = train
        super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[-1]
        self.data, self.slices = torch.load(path)
        print("Initialized")

    @property
    def raw_file_names(self):
        # Maybe add more data like validation sets
        return ['train', 'test']

    @property
    def processed_file_names(self):
        return [f'{x}.pt' for x in self.raw_file_names]

    def download(self):
        # No remote download: just verify that raw point-cloud folders are present.
        dir_count = len([name for name in os.listdir(self.raw_dir)
                         if os.path.isdir(os.path.join(self.raw_dir, name))])
        print(f'{dir_count} folders have been found...')
        if dir_count:
            return dir_count
        raise IOError("No raw pointclouds have been found.")

    @property
    def num_classes(self):
        return len(self.categories)

    def _load_dataset(self):
        data, slices = None, None
        while True:
            try:
                # self.processed_dir already contains self.root; do not join root twice.
                filepath = os.path.join(self.processed_dir, f'{"train" if self.train else "test"}.pt')
                data, slices = torch.load(filepath)
                print('Dataset Loaded')
                break
            except FileNotFoundError:
                self.process()
                continue
        return data, slices

    def _transform_and_filter(self, data):
        # Drop samples that the optional pre-filter rejects.
        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        # Apply the optional pre-transform before collating.
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data
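
    # Example (a sketch): torch_geometric ships transforms that fit this hook,
    # e.g. NormalizeScale, which centers each cloud and scales it into (-1, 1).
    # The `root` path below is a placeholder.
    #
    #   from torch_geometric.transforms import NormalizeScale
    #   dataset = CustomShapeNet(root='./data', pre_transform=NormalizeScale())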

    def process(self, delimiter=' '):
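        # Expected raw layout (inferred from the directory walk below):
        #   raw/<train|test>/<cloud_id>/*.dat | *.xyz    per-segment clouds (used here)
        #   raw/<train|test>/<cloud_id>/pc.dat | pc.xyz  full cloud (skipped here)
        # An optional '.headers' / 'no.headers' sentinel file in the split folder
        # overrides the `headers` constructor flag.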
        datasets = defaultdict(list)
        idx, data_folder = (0, self.raw_file_names[0]) if self.train else (1, self.raw_file_names[1])
        path_to_clouds = os.path.join(self.raw_dir, data_folder)

        # Per-split sentinel files override the constructor flag; if neither
        # is present, keep whatever was passed to __init__.
        if '.headers' in os.listdir(path_to_clouds):
            self.has_headers = True
        elif 'no.headers' in os.listdir(path_to_clouds):
            self.has_headers = False

        for pointcloud in tqdm(os.scandir(path_to_clouds)):
            if not pointcloud.is_dir():
                continue
            data, paths = None, list()
            for ext in ['dat', 'xyz']:
                paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))

            for element in paths:
                # Skip the merged full-cloud files; only per-segment files are read here.
                if all(x not in os.path.split(element)[-1] for x in ['pc.dat', 'pc.xyz']):
                    # Assign training data to the data container.
                    # Following the original logic:
                    #   y   - the class label per point
                    #   pos - the six-dimensional per-point vector
                    #         (x, y, z, x_rot, y_rot, z_rot): position plus rotation

                    # Derive the label from the category name embedded in the filename
                    y_raw = next(val for key, val in self.categories.items() if key.lower() in element.lower())
                    # y_raw = os.path.splitext(element)[0].split('_')[-2]
                    with open(element, 'r') as f:
                        if self.has_headers:
                            headers = next(f)
                            # Skip files whose header reports zero usable nodes.
                            if not int(headers.rstrip().split(delimiter)[0]):
                                continue

                        # Parse all rows, mapping non-numeric NaN markers to 0.
                        src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
                                for x in line.rstrip().split(delimiter)] for line in f if line.strip()]
                    points = torch.tensor(src, dtype=torch.float).squeeze()
                    if points.dim() <= 1:
                        continue  # a squeezed single-point cloud is unusable
                    # Only the xyz position is kept; normals (columns 3:) are unused.
                    y = torch.as_tensor([y_raw] * points.shape[0], dtype=torch.int)
                    if self.collate_per_segment:
                        # One Data object per segment file.
                        sample = self._transform_and_filter(Data(y=y, pos=points[:, :3]))
                        if sample is not None:
                            datasets[data_folder].append(sample)
                    else:
                        # Accumulate segments; they are merged into one cloud per folder below.
                        if data is None:
                            data = defaultdict(list)
                        for key, val in dict(y=y, pos=points[:, :3]).items():
                            data[key].append(val)
            if not self.collate_per_segment and data is not None:
                sample = self._transform_and_filter(Data(**{key: torch.cat(val) for key, val in data.items()}))
                if sample is not None:
                    datasets[data_folder].append(sample)

        if datasets[data_folder]:
            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[data_folder]), self.processed_paths[idx])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
    """
    def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True):
        super(ShapeNetPartSegDataset, self).__init__()
        self.npoints = npoints
        self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
                                      train=train, transform=transform, headers=headers)

    def __getitem__(self, index):
        data = self.dataset[index]
        points, labels = data.pos, data.y

        # Resample to a fixed number of points (with replacement). An empty
        # cloud raises ValueError, in which case an empty selection is used.
        try:
            choice = np.random.choice(points.shape[0], self.npoints, replace=True)
        except ValueError:
            choice = []

        points, labels = points[choice, :], labels[choice]

        # Map labels from [1, C] down to [0, C-1], but only when the raw
        # 1-based range is actually detected in this sample.
        labels -= 1 if self.num_classes() in labels else 0

        sample = {
            'points': points,  # torch.Tensor (n, 3)
            'labels': labels  # torch.Tensor (n,)
        }

        return sample

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
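
# DataLoader sketch (a minimal example; batch size and paths are placeholders):
#
#   from torch.utils.data import DataLoader
#   train_set = ShapeNetPartSegDataset(root_dir='./data', train=True)
#   loader = DataLoader(train_set, batch_size=8, shuffle=True)
#   for batch in loader:
#       points = batch['points']  # (B, npoints, 3)
#       labels = batch['labels']  # (B, npoints)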


class PredictionShapeNet(InMemoryDataset):
    # Same category-name -> integer-label mapping as CustomShapeNet.
    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}

    def __init__(self, root, transform=None, pre_filter=None, pre_transform=None,
                 headers=True, **kwargs):
        self.has_headers = headers
        super(PredictionShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0]
        self.data, self.slices = torch.load(path)
        print("Initialized")

    @property
    def raw_file_names(self):
        # Maybe add more data like validation sets
        return ['predict']

    @property
    def processed_file_names(self):
        return [f'{x}.pt' for x in self.raw_file_names]

    def download(self):
        # No remote download: just verify that raw point-cloud folders are present.
        dir_count = len([name for name in os.listdir(self.raw_dir)
                         if os.path.isdir(os.path.join(self.raw_dir, name))])
        print(f'{dir_count} folders have been found...')
        if dir_count:
            return dir_count
        raise IOError("No raw pointclouds have been found.")

    @property
    def num_classes(self):
        return len(self.categories)

    def _load_dataset(self):
        data, slices = None, None
        while True:
            try:
                # There is no train/test split here; load the single 'predict' set.
                filepath = self.processed_paths[0]
                data, slices = torch.load(filepath)
                print('Dataset Loaded')
                break
            except FileNotFoundError:
                self.process()
                continue
        return data, slices

    def process(self, delimiter=' '):

        datasets = defaultdict(list)
        for idx, setting in enumerate(self.raw_file_names):
            path_to_clouds = os.path.join(self.raw_dir, setting)

            # Per-split sentinel files override the constructor flag; if neither
            # is present, keep whatever was passed to __init__.
            if '.headers' in os.listdir(path_to_clouds):
                self.has_headers = True
            elif 'no.headers' in os.listdir(path_to_clouds):
                self.has_headers = False

            for pointcloud in tqdm(os.scandir(path_to_clouds)):
                if not pointcloud.is_dir():
                    continue
                # Prediction uses only the full cloud file, pc.dat / pc.xyz.
                for extension in ['dat', 'xyz']:
                    file = os.path.join(pointcloud.path, f'pc.{extension}')
                    if not os.path.exists(file):
                        continue
                    with open(file, 'r') as f:
                        if self.has_headers:
                            headers = next(f)
                            # Skip files whose header reports zero usable nodes.
                            if not int(headers.rstrip().split(delimiter)[0]):
                                continue

                        # Parse all rows, mapping non-numeric NaN markers to 0.
                        src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
                                for x in line.rstrip().split(delimiter)] for line in f if line.strip()]
                    points = torch.tensor(src, dtype=torch.float).squeeze()
                    if points.dim() <= 1:
                        continue  # a squeezed single-point cloud is unusable
                    # Only the xyz position is kept; normals (columns 3:) are unused.
                    # Prediction data has no ground truth, so attach placeholder labels.
                    y = torch.as_tensor([-1] * points.shape[0], dtype=torch.int)
                    data = Data(y=y, pos=points[:, :3])
                    # Drop samples that the optional pre-filter rejects.
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    # Apply the optional pre-transform before collating.
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    datasets[setting].append(data)

            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[setting]), self.processed_paths[idx])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class PredictNetPartSegDataset(Dataset):
    """
    Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
    """
    def __init__(self, root_dir, train=False, transform=None, npoints=2048, headers=True, collate_per_segment=False):
        super(PredictNetPartSegDataset, self).__init__()
        self.npoints = npoints
        self.dataset = PredictionShapeNet(root=root_dir, train=train, transform=transform,
                                          headers=headers, collate_per_segment=collate_per_segment)

    def __getitem__(self, index):
        data = self.dataset[index]
        points, labels = data.pos, data.y

        # Resample to fixed number of points
        try:
            choice = np.random.choice(points.shape[0], self.npoints, replace=True)
        except ValueError:
            choice = []

        points, labels = points[choice, :], labels[choice]

        # No-op for placeholder labels (-1); kept for parity with the training dataset.
        labels -= 1 if self.num_classes() in labels else 0

        sample = {
            'points': points,  # torch.Tensor (n, 3)
            'labels': labels  # torch.Tensor (n,)
        }

        return sample

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
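

# Minimal smoke test (a sketch; assumes raw data under './data/raw/train' laid
# out as described above the process() method).
if __name__ == '__main__':
    dataset = ShapeNetPartSegDataset(root_dir='./data', train=True, npoints=1024)
    print(f'{len(dataset)} samples, {dataset.num_classes()} classes')
    sample = dataset[0]
    print(sample['points'].shape, sample['labels'].shape)  # (1024, 3), (1024,)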