72 lines
2.8 KiB
Python
72 lines
2.8 KiB
Python
from abc import ABC
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
|
|
|
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
|
from ml_lib.modules.util import LightningBaseModule, F_x
|
|
from ml_lib.point_toolset.point_io import BatchToData
|
|
|
|
|
|
class _PointNetCore(LightningBaseModule, ABC):
    """Shared PointNet++ trunk: three set-abstraction stages, three
    feature-propagation stages and two 128 -> 128 linear layers.

    No task-specific output head is built here. Presumably concrete
    subclasses add one (TODO confirm against the subclasses).

    Hyperparameters are read from ``self.params``. This class does not show
    how ``self.params`` is populated; presumably ``LightningBaseModule``
    builds it from ``hparams`` (confirm in ``ml_lib.modules.util``). Keys read:

    - ``normals_as_cords``: when truthy, the layer widths assume 6-wide point
      coordinates and no separate per-point input features. Otherwise they
      assume 3-wide coordinates plus 3 per-point input features.
    - ``dropout``: dropout probability. A falsy value replaces dropout with
      ``F_x(None)``.
    - ``activation``: a callable instantiated with no arguments. Its result
      is applied in ``forward``.
    """

    def __init__(self, hparams):
        super().__init__(hparams=hparams)

        # Transforms
        # =============================================================================
        # Raw batches are converted to PyG data without any transforms.
        self.batch_to_data = BatchToData(transforms=None)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        # Width of the per-point coordinate vector: xyz, plus normals when
        # normals are used as coordinates.
        self.cord_dims = 6 if self.params.normals_as_cords else 3
        # Width of the per-point input features passed alongside the
        # coordinates. When the normals are folded into the coordinates there
        # are no separate input features.
        feature_dims = 0 if self.params.normals_as_cords else 3

        # Modules
        # Encoder. The first stage sees the raw input features plus
        # coordinates (6 wide in both configurations). Later stages see the
        # previous stage's features plus coordinates.
        # NOTE(review): the leading positional arguments presumably are
        # (sampling ratio, radius) as in the PyG PointNet++ example — confirm
        # in ml_lib.modules.geometric_blocks.
        self.sa1_module = SAModule(0.2, 0.2, MLP([feature_dims + self.cord_dims, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)

        # Decoder. Each MLP input is the coarser level's width plus the width
        # of the skip connection from the matching finer level.
        # NOTE(review): the first argument presumably is the number of
        # neighbours used for interpolation — confirm in geometric_blocks.
        self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
        self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
        self.fp1_module = FPModule(3, MLP([128 + feature_dims, 128, 128, 128]))

        self.lin1 = nn.Linear(128, 128)
        self.lin2 = nn.Linear(128, 128)

        # Utility
        self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
        self.activation = self.params.activation()

    def forward(self, sa0_out, **kwargs):
        """Run the PointNet++ trunk on one batch of point clouds.

        Args:
            sa0_out: The input batch as an unpackable ``(x, pos, batch)``
                sequence. It is star-unpacked into ``sa1_module`` and
                ``fp1_module``, so it is not a ``Data`` object itself.

                - ``x``: ``(batch_size * ~num_points, C)`` per-point features.
                  ``~num_points`` means each sample may have a different
                  number of points. Presumably ``None`` when
                  ``normals_as_cords`` is set (the layer widths assume no
                  input features then) — TODO confirm.
                - ``pos``: ``(batch_size * ~num_points, cord_dims)`` point
                  coordinates.
                - ``batch``: ``(batch_size * ~num_points,)`` vector of
                  graph/point-cloud identifiers, one per point. See the
                  PyTorch Geometric documentation for more information.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A float32 tensor whose last dimension is 128. Presumably it has
            one row per input point, since ``fp1_module`` propagates onto the
            ``sa0_out`` points.
        """
        # Encoder: progressively downsample and aggregate.
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)

        # Decoder: each FP stage consumes the coarser level's output followed
        # by the skip connection from the matching finer level.
        fp3_out = self.fp3_module(*sa3_out, *sa2_out)
        fp2_out = self.fp2_module(*fp3_out, *sa1_out)
        tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)

        # Cast to float32 before the linear layers.
        tensor = tensor.float()

        # NOTE(review): the activation is applied only before lin1. Between
        # lin1 and lin2 there is only dropout, so when dropout is inactive
        # the two layers compose to a single affine map. Confirm whether an
        # activation after lin1 was intended.
        tensor = self.activation(tensor)
        tensor = self.lin1(tensor)
        tensor = self.dropout(tensor)
        tensor = self.lin2(tensor)
        tensor = self.dropout(tensor)
        return tensor
|