84 lines
3.4 KiB
Python
84 lines
3.4 KiB
Python
from argparse import Namespace
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale
|
|
|
|
from datasets.shapenet import ShapeNetPartSegDataset
|
|
from models._point_net_2 import _PointNetCore
|
|
|
|
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
|
from utils.project_settings import GlobalVar
|
|
|
|
|
|
class PointNet2(BaseValMixin, BaseTrainMixin, BaseOptimizerMixin,
                DatasetMixin, BaseDataloadersMixin, _PointNetCore):
    """PointNet++ model trained on ``ShapeNetPartSegDataset``.

    Combines the project's validation/training/optimizer/dataset/dataloader
    mixins with the ``_PointNetCore`` backbone, and adds a final linear layer
    followed by a log-softmax.
    """

    def __init__(self, hparams):
        """Build the augmentation pipeline, the dataset and the output head.

        Args:
            hparams: hyper-parameters forwarded to the base classes. The code
                below reads ``npoints``, ``cluster_type``, ``refresh`` and
                ``poly_as_plane`` from ``self.params`` (presumably populated
                from ``hparams`` by one of the base classes — TODO confirm).
        """
        super().__init__(hparams=hparams)

        # Data augmentation / preprocessing
        # =============================================================================
        # Random rotation (RandomRotate about each axis) is intentionally left
        # out: it is not available when normals are used as extra (6-dim)
        # coordinates.
        trans_max_distance = 0.01
        # NOTE(review): RandomFlip presumably mirrors only ``data.pos``; if the
        # normals in ``data.norm`` are not flipped as well, flipped samples carry
        # inconsistent normals. Confirm and flip the normal component too if so.
        transforms = Compose([
            RandomFlip(0, p=0.8),                 # mirror along axis 0 with probability 0.8
            FixedPoints(self.params.npoints),     # sample a fixed number of points per cloud
            RandomTranslate(trans_max_distance),  # small random translation per position
            NormalizeScale(),                     # center and rescale positions
        ])

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(ShapeNetPartSegDataset,
                                          collate_per_segment=True,
                                          transform=transforms,
                                          cluster_type=self.params.cluster_type,
                                          refresh=self.params.refresh,
                                          poly_as_plane=self.params.poly_as_plane,
                                          )

        # Model parameters
        # =============================================================================
        # With ``poly_as_plane`` the head predicts two classes fewer than
        # ``GlobalVar.classes`` lists (presumably because the dataset folds those
        # classes into others — TODO confirm against ShapeNetPartSegDataset).
        n_total_classes = len(GlobalVar.classes)
        self.n_classes = n_total_classes - 2 if self.params.poly_as_plane else n_total_classes

        # Modules
        # Final classifier: maps the 128-dim backbone features to class scores.
        self.lin3 = nn.Linear(128, self.n_classes)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Compute class log-probabilities for a batch of point clouds.

        Args:
            data: a batch of ``torch_geometric.data.Data`` providing:

                - ``data.pos``: (batch_size * ~num_points, 3) point coordinates;
                  ~num_points means each sample can have a different number
                  of points/nodes.
                - ``data.norm``: per-point normals with the same number of rows
                  as ``data.pos`` (presumably (batch_size * ~num_points, 3)).
                - ``data.batch``: (batch_size * ~num_points,) vector of graph /
                  point-cloud identifiers for all points in the batch. See the
                  PyTorch Geometric documentation for more information.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            ``Namespace(main_out=...)``, where ``main_out`` holds log-softmax
            scores over the last dimension (``self.n_classes`` entries per row;
            presumably one row per point — TODO confirm the backbone output).
        """
        if self.params.normals_as_cords:
            # Use normals as extra coordinates: concatenate them onto the
            # positions and pass no separate point features.
            sa0_out = (None, torch.cat((data.pos, data.norm), dim=-1), data.batch)
        else:
            # Use normals as point features on top of the 3-dim positions.
            sa0_out = (data.norm, data.pos, data.batch)

        tensor = super().forward(sa0_out)
        tensor = self.lin3(tensor)
        tensor = self.log_softmax(tensor)
        return Namespace(main_out=tensor)
|