from argparse import Namespace

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip

from datasets.grid_clusters import GridClusters
from models._point_net_2 import _PointNetCore

from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar


class PointNet2GridClusters(BaseValMixin,
                            BaseTrainMixin,
                            BaseOptimizerMixin,
                            DatasetMixin,
                            BaseDataloadersMixin,
                            _PointNetCore
                            ):
    """`_PointNetCore` model with two log-softmax classification heads.

    The shared feature tensor returned by ``_PointNetCore.forward`` is fed to
    two linear layers: ``point_lin`` (one output per entry of
    ``GlobalVar.classes``) and ``grid_lin`` (``GlobalVar.grid_count``
    outputs). The training and validation losses are the sum of the losses of
    both heads.
    """

    def __init__(self, hparams):
        """Build the dataset and the two output heads.

        Args:
            hparams: Hyper-parameter container forwarded unchanged to the
                parent initialisers. ``self.params.sampling_k`` is read right
                after the parent initialisers ran.
        """
        super().__init__(hparams=hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        n_classes = len(GlobalVar.classes)
        self.n_classes = n_classes

        # Modules
        # NOTE(review): 128 presumably matches the feature width returned by
        # _PointNetCore.forward -- confirm against the core implementation.
        self.point_lin = nn.Linear(128, n_classes)
        self.grid_lin = nn.Linear(128, GlobalVar.grid_count)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def _forward_with_losses(self, batch_pos_x_n_y_c):
        """Run the forward pass and compute the loss of each head.

        Shared by ``training_step`` and ``validation_step`` so both compute
        their losses in exactly the same way.

        Args:
            batch_pos_x_n_y_c: Either a ``torch_geometric.data.Data`` object,
                which is used as is, or an iterable that is unpacked into
                ``self.batch_to_data`` to build one.

        Returns:
            Tuple ``(data, y, nll_main_loss, nll_cluster_loss)`` where ``data``
            is the (possibly converted) batch, ``y`` the ``Namespace`` returned
            by ``forward``, ``nll_main_loss`` the loss of ``y.main_out``
            against ``data.yl`` and ``nll_cluster_loss`` the loss of
            ``y.grid_out`` against ``data.yc``.
        """
        if isinstance(batch_pos_x_n_y_c, Data):
            data = batch_pos_x_n_y_c
        else:
            data = self.batch_to_data(*batch_pos_x_n_y_c)

        y = self(data)
        nll_main_loss = self.nll_loss(y.main_out, data.yl)
        nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
        return data, y, nll_main_loss, nll_cluster_loss

    def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
        """Compute the combined training loss for one batch.

        Args:
            batch_pos_x_n_y_c: Batch as a ``Data`` object or as an iterable
                accepted by ``self.batch_to_data``.
            batch_nb: Batch number; only passed through to the ``log`` entry.

        Returns:
            dict with ``loss`` (sum of both head losses), ``log`` (containing
            ``batch_nb``), ``nll_cluster_loss`` and ``nll_main_loss``.
        """
        _, _, nll_main_loss, nll_cluster_loss = self._forward_with_losses(batch_pos_x_n_y_c)
        nll_loss = nll_main_loss + nll_cluster_loss
        return dict(loss=nll_loss, log=dict(batch_nb=batch_nb),
                    nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss)

    def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
        """Compute the combined validation loss for one batch.

        Args:
            batch_pos_x_n_y_c: Batch as a ``Data`` object or as an iterable
                accepted by ``self.batch_to_data``.
            batch_idx: Batch index; passed through unchanged in the result.

        Returns:
            dict with ``val_nll_loss`` (sum of both head losses),
            ``nll_cluster_loss``, ``nll_main_loss``, ``batch_idx``, ``y``
            (the main head output) and ``batch_y`` (``data.yl``).
        """
        data, y, nll_main_loss, nll_cluster_loss = self._forward_with_losses(batch_pos_x_n_y_c)
        nll_loss = nll_main_loss + nll_cluster_loss
        return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss,
                    batch_idx=batch_idx, y=y.main_out, batch_y=data.yl)

    def forward(self, data, **kwargs):
        """Apply the core network and both output heads to a batch.

        Args:
            data: Batch object (documented by the original author as a
                ``torch_geometric.data.Data`` batch). Only these attributes
                are read here:

                data.norm: passed to the core as the first element of its
                    input tuple, i.e. in the feature slot. Presumably
                    per-point normals of shape (batch_size * ~num_points, C),
                    where ~num_points means each sample can have a different
                    number of points -- TODO confirm.

                data.pos: point positions, presumably
                    (batch_size * ~num_points, 3) -- TODO confirm.

                data.batch: presumably (batch_size * ~num_points,), a vector
                    of graph/pointcloud identifiers for all nodes of all
                    graphs/pointclouds in the batch. See the
                    pytorch_geometric documentation for more information.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            argparse.Namespace with
                main_out: log-softmax (over the last dimension) of
                    ``point_lin`` applied to the core features.
                grid_out: log-softmax (over the last dimension) of
                    ``grid_lin`` applied to the core features.
        """
        sa0_out = (data.norm, data.pos, data.batch)
        tensor = super().forward(sa0_out)

        point_tensor = self.log_softmax(self.point_lin(tensor))
        grid_tensor = self.log_softmax(self.grid_lin(tensor))
        return Namespace(main_out=point_tensor, grid_out=grid_tensor)