from argparse import Namespace

import torch
from torch import nn

from datasets.full_pointclouds import FullCloudsDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import (
    BaseValMixin,
    BaseTrainMixin,
    BaseOptimizerMixin,
    BaseDataloadersMixin,
    DatasetMixin,
)
from utils.project_config import GlobalVar


class PointNet2PrimClusters(
    BaseValMixin,
    BaseTrainMixin,
    BaseOptimizerMixin,
    DatasetMixin,
    BaseDataloadersMixin,
    _PointNetCore,
):
    """PointNet++ model with two log-softmax heads on shared core features.

    The shared feature tensor produced by ``_PointNetCore.forward`` is fed to:

    * ``point_lin`` -> ``len(GlobalVar.classes)`` outputs (returned as ``main_out``)
    * ``prim_lin``  -> ``len(GlobalVar.prims)`` outputs (returned as ``prim_out``)

    Both heads are followed by a log-softmax over the last dimension.
    """

    def __init__(self, hparams):
        super().__init__(hparams=hparams)

        # Dataset
        # =============================================================================
        # Built before the heads are created, matching the original construction order.
        self.dataset = self.build_dataset(FullCloudsDataset, setting='prim')

        # Model parameters
        # =============================================================================
        # NOTE(review): the input width 128 is assumed to equal the feature width
        # returned by _PointNetCore.forward — confirm against that class.
        # Registration order (point head, then primitive head) is kept so the
        # parameter / state_dict ordering stays stable for existing checkpoints.
        self.point_lin = nn.Linear(128, len(GlobalVar.classes))
        self.prim_lin = nn.Linear(128, len(GlobalVar.prims))

        # Stateless normalisation shared by both heads.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Run the shared PointNet++ core and both classification heads.

        Args:
            data: a torch_geometric batch (``torch_geometric.data.Data``) with
                data.x: (batch_size * ~num_points, C) per-point features; each
                    sample may contribute a different number of points.
                data.pos: (batch_size * ~num_points, 3) point positions.
                data.batch: (batch_size * ~num_points,) vector assigning each
                    point to its point cloud within the batch. See the
                    pytorch_geometric documentation for details.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            Namespace with ``main_out`` (log-probabilities over
            ``GlobalVar.classes``) and ``prim_out`` (log-probabilities over
            ``GlobalVar.prims``).
        """
        # The core expects the (features, positions, batch-index) triple.
        features = super().forward((data.x, data.pos, data.batch))

        main_log_probs = self.log_softmax(self.point_lin(features))
        prim_log_probs = self.log_softmax(self.prim_lin(features))

        return Namespace(main_out=main_log_probs, prim_out=prim_log_probs)