from argparse import Namespace

import torch
from torch import nn
from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale

from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_settings import GlobalVar


class PointNet2(BaseValMixin,
                BaseTrainMixin,
                BaseOptimizerMixin,
                DatasetMixin,
                BaseDataloadersMixin,
                _PointNetCore
                ):
    """PointNet++ per-point classifier trained on ``ShapeNetPartSegDataset``.

    Combines the project's validation/training/optimizer/dataset/dataloader mixins
    with ``_PointNetCore`` and adds a linear + log-softmax classification head on
    top of the core's output features.
    """

    def __init__(self, hparams):
        """Build the augmentation pipeline, the dataset and the classification head.

        Args:
            hparams: hyper-parameters forwarded unchanged to the base classes. This
                method reads ``npoints``, ``cluster_type``, ``refresh`` and
                ``poly_as_plane`` through ``self.params`` (presumably populated from
                ``hparams`` by a base class -- confirm).
        """
        super().__init__(hparams=hparams)

        # Transforms
        # =============================================================================
        # NOTE: RandomRotate is deliberately left out of this pipeline; the original
        # author noted it is "not available with 6-dim coordinates" (pos + normals).
        trans_max_distance = 0.01
        transforms = Compose(
            [
                # NOTE(review): per the PyG docs RandomFlip mirrors ``pos`` only; if
                # ``norm`` holds surface normals they are not flipped with it -- confirm
                # this is intended.
                RandomFlip(0, p=0.8),
                FixedPoints(self.params.npoints),
                RandomTranslate(trans_max_distance),
                NormalizeScale(),
            ]
        )

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(ShapeNetPartSegDataset,
                                          collate_per_segment=True,
                                          transform=transforms,
                                          cluster_type=self.params.cluster_type,
                                          refresh=self.params.refresh,
                                          poly_as_plane=self.params.poly_as_plane
                                          )

        # Model Parameters
        # =============================================================================
        # Number of output classes; ``poly_as_plane`` drops two classes from
        # GlobalVar.classes (presumably folded into "plane" by the dataset -- confirm).
        num_classes = len(GlobalVar.classes)
        self.n_classes = (num_classes - 2) if self.params.poly_as_plane else num_classes

        # Modules
        # The input width of 128 must match the feature width produced by
        # _PointNetCore.forward (defined elsewhere -- confirm when changing the core).
        self.lin3 = torch.nn.Linear(128, self.n_classes)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Classify every point of a batched point cloud.

        Args:
            data: a ``torch_geometric.data.Data`` batch. Fields read here:
                ``data.pos``   -- point positions, presumably
                                  ``(batch_size * ~num_points, 3)``;
                ``data.norm``  -- per-point features (presumably surface normals),
                                  same leading dimension as ``data.pos``;
                ``data.batch`` -- ``(batch_size * ~num_points,)`` vector assigning each
                                  point to its point cloud in the batch.
                ``~num_points`` means each sample may contain a different number of
                points. See the pytorch_geometric documentation for more information.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``Namespace(main_out=...)`` where ``main_out`` holds log-probabilities over
            ``self.n_classes`` classes (log-softmax over the last dimension).
        """
        if self.params.normals_as_cords:
            # Feed the normals as extra coordinate channels instead of node features.
            pos_cat_norm = torch.cat((data.pos, data.norm), dim=-1)
            sa0_out = (None, pos_cat_norm, data.batch)
        else:
            sa0_out = (data.norm, data.pos, data.batch)

        tensor = super().forward(sa0_out)
        tensor = self.lin3(tensor)
        tensor = self.log_softmax(tensor)
        return Namespace(main_out=tensor)