from argparse import Namespace

import torch
from torch import nn
from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale

from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore

from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin


class PointNet2(BaseValMixin,
                BaseTrainMixin,
                BaseOptimizerMixin,
                DatasetMixin,
                BaseDataloadersMixin,
                _PointNetCore
                ):
    """PointNet++ model trained on ``ShapeNetPartSegDataset``.

    Combines the project's validation/training/optimizer/dataset/dataloader
    mixins with the ``_PointNetCore`` backbone, and adds a linear
    classification head followed by a log-softmax.
    """

    def __init__(self, hparams):
        """Build the augmented dataset and the classification head.

        Args:
            hparams: hyper-parameters forwarded unchanged to the base classes.
                The values used here (``npoints``, ``cluster_type``,
                ``refresh``, ``poly_as_plane``) are read via ``self.params``,
                which is presumably populated from ``hparams`` by a base
                class -- TODO confirm.
        """
        # Zero-argument super() is equivalent to super(PointNet2, self) and
        # still follows the full MRO through all mixins.
        super().__init__(hparams=hparams)

        # Data augmentation / preprocessing
        # =============================================================================
        trans_max_distance = 0.01
        transforms = Compose(
            [
                RandomFlip(0, p=0.8),
                FixedPoints(self.params.npoints),
                # RandomRotate is deliberately omitted: per the original note it
                # is not available with 6-dim coordinates.
                RandomTranslate(trans_max_distance),
                # NormalizeScale / NormalizePositions are currently disabled.
            ]
        )

        # Dataset
        # =============================================================================
        # NOTE(review): this randomised transform is handed to build_dataset as
        # a single `transform`; confirm whether it is also applied to the
        # validation/test splits.
        self.dataset = self.build_dataset(ShapeNetPartSegDataset,
                                          transform=transforms,
                                          cluster_type=self.params.cluster_type,
                                          refresh=self.params.refresh,
                                          poly_as_plane=self.params.poly_as_plane
                                          )

        # Model parameters
        # =============================================================================
        # Number of output classes, taken from the training split.
        self.n_classes = len(self.dataset.train_dataset.classes)

        # Modules
        # NOTE(review): input width 128 is assumed to match the output width
        # of _PointNetCore.forward -- confirm if the backbone changes.
        self.lin3 = nn.Linear(128, self.n_classes)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Run the PointNet++ backbone and the classification head.

        Args:
            data: a torch_geometric batch (``torch_geometric.data.Data``) with
                - ``data.pos``: (batch_size * ~num_points, 3) point positions;
                - ``data.norm``: per-point normals, one row per point;
                - ``data.batch``: (batch_size * ~num_points,) vector of
                  point-cloud identifiers for every point in the batch (see
                  the PyTorch Geometric documentation).
                ``~num_points`` means each sample may have a different number
                of points.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            ``Namespace(main_out=...)``, where ``main_out`` holds
            log-probabilities over the last dimension (``self.n_classes``
            classes).

        If ``self.params.normals_as_cords`` is true, ``pos`` and ``norm`` are
        concatenated along the last dimension and used as coordinates with no
        separate features. Otherwise ``norm`` is used as per-point features
        and ``pos`` as coordinates.
        """
        if self.params.normals_as_cords:
            # Coordinates = positions concatenated with normals; no features.
            pos_cat_norm = torch.cat((data.pos, data.norm), dim=-1)
            sa0_out = (None, pos_cat_norm, data.batch)
        else:
            # Normals act as features, positions as coordinates.
            sa0_out = (data.norm, data.pos, data.batch)
        tensor = super().forward(sa0_out)
        tensor = self.lin3(tensor)
        return Namespace(main_out=self.log_softmax(tensor))