from argparse import Namespace

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip

from datasets.grid_clusters import GridClusters
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar


class PointNet2GridClusters(BaseValMixin,
                            BaseTrainMixin,
                            BaseOptimizerMixin,
                            DatasetMixin,
                            BaseDataloadersMixin,
                            _PointNetCore
                            ):
    """PointNet++ model with two classification heads on a shared ``_PointNetCore`` backbone.

    * ``point_lin`` -> ``main_out``: log-probabilities over ``GlobalVar.classes``,
      supervised with ``data.yl``.
    * ``grid_lin``  -> ``grid_out``: log-probabilities over ``GlobalVar.grid_count``
      grid cells, supervised with ``data.yc``.

    The training objective is the unweighted sum of the two NLL losses.
    """

    def __init__(self, hparams):
        super().__init__(hparams=hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_classes = len(GlobalVar.classes)

        # Modules
        # NOTE(review): 128 is assumed to be the feature width produced by
        # _PointNetCore.forward — confirm against models._point_net_2.
        # point_lin is created before grid_lin to keep parameter-init order unchanged.
        self.point_lin = nn.Linear(128, self.n_classes)
        self.grid_lin = nn.Linear(128, GlobalVar.grid_count)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def _as_data(self, batch):
        """Return *batch* as a ``Data`` object.

        A batch that is already a ``torch_geometric.data.Data`` is returned as-is.
        Otherwise it is unpacked positionally into ``self.batch_to_data`` (provided by a
        base class / mixin not shown here).
        """
        return batch if isinstance(batch, Data) else self.batch_to_data(*batch)

    def _compute_losses(self, batch):
        """Run the model on *batch* and compute both NLL losses.

        Returns a ``Namespace`` with:
            data:         the ``Data`` object that was fed to the model
            y:            the model output (``main_out`` / ``grid_out``)
            main_loss:    NLL of ``y.main_out`` against ``data.yl``
            cluster_loss: NLL of ``y.grid_out`` against ``data.yc``
            total_loss:   ``main_loss + cluster_loss``
        """
        data = self._as_data(batch)
        y = self(data)
        # Both heads already emit log-probabilities (see forward), so NLL is the matching loss.
        main_loss = self.nll_loss(y.main_out, data.yl)
        cluster_loss = self.nll_loss(y.grid_out, data.yc)
        return Namespace(data=data, y=y, main_loss=main_loss, cluster_loss=cluster_loss,
                         total_loss=main_loss + cluster_loss)

    def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
        """Compute the summed NLL loss for one training batch.

        Args:
            batch_pos_x_n_y_c: either a ``Data`` object or a tuple accepted by ``batch_to_data``.
            batch_nb: index of the batch; only forwarded into the ``log`` dict.

        Returns:
            dict with ``loss`` (main + cluster NLL), ``log``, and the two partial losses.
        """
        out = self._compute_losses(batch_pos_x_n_y_c)
        return dict(loss=out.total_loss, log=dict(batch_nb=batch_nb),
                    nll_cluster_loss=out.cluster_loss, nll_main_loss=out.main_loss)

    def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
        """Compute the summed NLL loss for one validation batch.

        Args:
            batch_pos_x_n_y_c: either a ``Data`` object or a tuple accepted by ``batch_to_data``.
            batch_idx: index of the batch; returned unchanged.

        Returns:
            dict with ``val_nll_loss``, the two partial losses, ``batch_idx``, the main-head
            log-probabilities ``y`` and their targets ``batch_y`` (``data.yl``).
        """
        out = self._compute_losses(batch_pos_x_n_y_c)
        return dict(val_nll_loss=out.total_loss, nll_cluster_loss=out.cluster_loss,
                    nll_main_loss=out.main_loss, batch_idx=batch_idx,
                    y=out.y.main_out, batch_y=out.data.yl)

    def forward(self, data, **kwargs):
        """Run the backbone and both heads on a torch_geometric batch.

        Args:
            data: a ``torch_geometric.data.Data`` batch. Attributes read here:
                data.norm:  per-point feature tensor passed to the core as its first input
                            (presumably point normals — confirm with the dataset).
                data.pos:   point coordinates.
                data.batch: per-point graph/point-cloud identifiers for all points in the
                            batch. See the PyTorch Geometric documentation for details.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Namespace(main_out=..., grid_out=...): log-softmax (over the last dim) outputs of
            the class head and the grid-cell head respectively.
        """
        sa0_out = (data.norm, data.pos, data.batch)
        features = super().forward(sa0_out)

        point_tensor = self.log_softmax(self.point_lin(features))
        grid_tensor = self.log_softmax(self.grid_lin(features))

        return Namespace(main_out=point_tensor, grid_out=grid_tensor)