from argparse import Namespace

import torch
from torch import nn

from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import (
    BaseValMixin,
    BaseTrainMixin,
    BaseOptimizerMixin,
    BaseDataloadersMixin,
    DatasetMixin,
)
from utils.project_config import GlobalVar


class PointNet2(BaseValMixin,
                BaseTrainMixin,
                BaseOptimizerMixin,
                DatasetMixin,
                BaseDataloadersMixin,
                _PointNetCore
                ):
    """PointNet++ network wired to the ShapeNet part-segmentation dataset.

    Combines the project's train/val/optimizer/dataloader/dataset mixins with
    ``_PointNetCore``. The core's output is projected by a linear head from 128
    features to ``len(GlobalVar.classes)`` classes, followed by a log-softmax.

    NOTE(review): the base-class order defines the MRO. Both
    ``super().__init__`` and ``super().forward`` dispatch through it, so do not
    reorder the bases without checking every mixin.
    """

    def __init__(self, hparams):
        """Build the dataset, the classification head and the output activation.

        Args:
            hparams: hyper-parameter container forwarded to the base classes.
                ``self.params.npoints`` is read afterwards, so presumably one of
                the mixins exposes ``hparams`` as ``self.params`` -- confirm.
        """
        super(PointNet2, self).__init__(hparams=hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(ShapeNetPartSegDataset,
                                          collate_per_segment=True,
                                          npoints=self.params.npoints)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_classes = len(GlobalVar.classes)

        # Modules
        # NOTE(review): placeholder empty tuple, not a module. It is kept because
        # other code may read this attribute; confirm whether it can be removed.
        self.point_net_core = ()
        # Head size derived from self.n_classes so the two can never disagree.
        # Assumes the core's forward emits 128 features on the last dim.
        self.lin3 = torch.nn.Linear(128, self.n_classes)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Run the PointNet++ core and classify its features.

        Args:
            data: a torch_geometric-style batch. Only these attributes are read:
                - ``data.norm``: per-point features fed to the first
                  set-abstraction stage (presumably surface normals,
                  ``(batch_size * ~num_points, 3)`` -- confirm against the
                  dataset). ``~num_points`` means each sample may have a
                  different number of points.
                - ``data.pos``: ``(batch_size * ~num_points, 3)`` point
                  coordinates.
                - ``data.batch``: ``(batch_size * ~num_points,)`` vector of
                  graph/point-cloud identifiers for every point in the batch.
                  See the PyTorch Geometric documentation for details.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``argparse.Namespace`` whose ``main_out`` holds log-probabilities
            over ``self.n_classes`` classes along the last dimension.
        """
        # Input tuple for the first set-abstraction stage, in the order the core expects.
        sa0_out = (data.norm, data.pos, data.batch)
        # Dispatches through the MRO; presumably resolves to _PointNetCore.forward.
        tensor = super(PointNet2, self).forward(sa0_out)
        tensor = self.lin3(tensor)
        tensor = self.log_softmax(tensor)
        return Namespace(main_out=tensor)