point_to_primitive/models/point_net_2.py
2020-07-03 14:40:28 +02:00

82 lines
3.2 KiB
Python

from argparse import Namespace
import torch
from torch import nn
from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale
from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
class PointNet2(BaseValMixin,
                BaseTrainMixin,
                BaseOptimizerMixin,
                DatasetMixin,
                BaseDataloadersMixin,
                _PointNetCore
                ):
    """PointNet2 network wired to the ShapeNet part-segmentation dataset.

    The class stacks the validation / training / optimizer / dataset /
    dataloader mixins on top of ``_PointNetCore``.  On construction it builds
    the dataset (with its augmentation pipeline) and appends a final linear
    layer followed by a log-softmax whose width equals the number of classes
    reported by the training split.
    """

    def __init__(self, hparams):
        """Build the dataset and the classification head.

        Args:
            hparams: Hyper-parameter container, forwarded unchanged to the
                parent constructors as ``hparams=``.  Afterwards the values
                ``npoints``, ``cluster_type``, ``refresh`` and
                ``poly_as_plane`` are read from ``self.params``
                (presumably populated from ``hparams`` by a base class --
                TODO confirm).
        """
        super().__init__(hparams=hparams)

        # Augmentation pipeline
        # =============================================================================
        # NOTE: random rotation around the three axes (max angle 15) is
        # intentionally left out -- it is not available with 6-dim
        # coordinates.  Scale / position normalisation is disabled as well.
        max_shift = 0.01
        augmentations = Compose(
            [
                RandomFlip(0, p=0.8),
                FixedPoints(self.params.npoints),
                RandomTranslate(max_shift),
            ]
        )

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(
            ShapeNetPartSegDataset,
            transform=augmentations,
            cluster_type=self.params.cluster_type,
            refresh=self.params.refresh,
            poly_as_plane=self.params.poly_as_plane,
        )

        # Model parameters
        # =============================================================================
        # One output unit per class known to the training split.
        self.n_classes = len(self.dataset.train_dataset.classes)

        # Modules
        # 128 input features -- presumably the feature width produced by
        # ``_PointNetCore.forward``; verify against models._point_net_2.
        self.lin3 = nn.Linear(128, self.n_classes)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """Run the core network and return log-probabilities.

        Args:
            data: Batch object exposing the attributes ``pos``, ``norm`` and
                ``batch`` (expected to be a torch_geometric batch, where
                ``batch`` holds the per-point sample identifiers; see the
                torch_geometric documentation).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            argparse.Namespace: ``main_out`` holds the log-softmax (over the
            last dimension) of the final linear layer's output.
        """
        if self.params.normals_as_cords:
            # Normals are appended to the positions along the last dimension
            # and no separate feature tensor is passed on.
            coords_with_normals = torch.cat((data.pos, data.norm), dim=-1)
            sa0_out = (None, coords_with_normals, data.batch)
        else:
            # Normals act as the point features, positions stay untouched.
            sa0_out = (data.norm, data.pos, data.batch)

        features = super().forward(sa0_out)
        logits = self.lin3(features)
        return Namespace(main_out=self.log_softmax(logits))