# Labeling-Primitives-With-Po.../model/pointnet2_part_seg.py
# 2019-08-09 10:25:16 +02:00
#
# 288 lines
# 9.8 KiB
# Python

import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
from torch_geometric.nn import PointConv, fps, radius, knn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max
# Number of columns of the per-point `pos` matrix used throughout this file.
# The first-layer MLP input width and the extra "+ GLOBAL_POINT_FEATURES"
# terms of the later MLP input widths in PointNet2PartSegmentNet are derived
# from this value, so changing it changes the expected input channel count.
GLOBAL_POINT_FEATURES = 6
class PointNet2SAModule(torch.nn.Module):
    """Set-abstraction layer: sample points, group neighbours, apply PointConv.

    Args:
        sample_radio: Ratio handed to ``fps`` (stored as ``sample_ratio``).
            The misspelled parameter name is kept for caller compatibility.
        radius: Search radius handed to the ``radius`` neighbour query.
        max_num_neighbors: Cap on neighbours per sampled point, forwarded to
            the ``radius`` query.
        mlp: Module wrapped by ``PointConv``.
    """

    def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
        super().__init__()
        self.sample_ratio = sample_radio
        self.radius = radius
        self.max_num_neighbors = max_num_neighbors
        self.point_conv = PointConv(mlp)

    def forward(self, data):
        """Run the layer on an ``(x, pos, batch)`` tuple.

        Returns:
            ``(x1, pos1, batch1)``: the PointConv output together with the
            rows of ``pos`` and ``batch`` selected by the sampling step.
        """
        feats, coords, batch_ids = data

        # Sample: indices of the points kept by this layer.
        keep = fps(coords, batch_ids, ratio=self.sample_ratio)
        kept_coords = coords[keep]
        kept_batch = batch_ids[keep]

        # Group: the query returns a pair of index vectors which are stacked
        # in swapped order. Presumably the first indexes the sampled points
        # and the second the input points (torch_geometric convention), so
        # the edges run input point -> sampled point; confirm against the
        # installed torch_geometric version.
        first, second = radius(
            coords,
            kept_coords,
            self.radius,
            batch_ids,
            kept_batch,
            max_num_neighbors=self.max_num_neighbors,
        )
        edges = torch.stack((second, first), dim=0)

        # Apply PointNet on the bipartite (input, sampled) graph.
        out = self.point_conv(feats, (coords, kept_coords), edges)
        return out, kept_coords, kept_batch
class PointNet2GlobalSAModule(torch.nn.Module):
    """Global set-abstraction layer (a plain PointNet over each cloud).

    All points of a cloud form a single group: every point is pushed through
    ``mlp`` and the results are max-pooled per batch entry. Each cloud is
    reduced to one output point, which is placed at the origin.

    Args:
        mlp: Module applied to the per-point input. Its input width must be
            ``x.shape[1] + pos.shape[1]`` (or ``pos.shape[1]`` when no
            features are supplied).
    """

    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp

    def forward(self, data):
        """Pool an ``(x, pos, batch)`` tuple down to one point per cloud.

        Args:
            data: ``(x, pos, batch)``; ``x`` may be ``None``, in which case
                the positions alone are used as per-point input.

        Returns:
            ``(x1, pos1, batch1)`` where ``x1`` is the per-cloud max-pooled
            feature, ``pos1`` is an all-zero position per cloud with the same
            number of columns as ``pos``, and ``batch1`` is
            ``0 .. batch_size - 1``.
        """
        x, pos, batch = data

        # Previously a missing feature tensor was forwarded to the MLP as
        # None and crashed; fall back to the positions instead.
        point_input = pos if x is None else torch.cat([x, pos], dim=1)
        point_feats = self.mlp(point_input)

        # scatter_max returns (values, argmax); only the values are needed.
        x1 = scatter_max(point_feats, batch, dim=0)[0]  # (batch_size, C1)
        batch_size = x1.shape[0]

        # The single output point lives in the same coordinate space as the
        # input, so its width follows `pos` rather than a fixed constant.
        pos1 = x1.new_zeros((batch_size, pos.shape[1]))
        batch1 = torch.arange(batch_size, device=batch.device, dtype=batch.dtype)
        return x1, pos1, batch1
class PointConvFP(MessagePassing):
    """Core layer of the feature-propagation module.

    Source features are carried along the edges, weighted by normalised
    inverse distance, summed per target node, concatenated with the target's
    own features (if any) and position, and optionally passed through ``mlp``.
    """

    def __init__(self, mlp=None):
        super().__init__('add', 'source_to_target')
        self.mlp = mlp
        self.aggr = 'add'
        self.flow = 'source_to_target'
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the wrapped MLP via torch_geometric's ``reset``."""
        reset(self.mlp)

    def forward(self, x, pos, edge_index):
        r"""Propagate features along ``edge_index`` and fuse with the targets.

        Args:
            x (tuple): ``(tensor, tensor)`` or ``(tensor, None)``.
            pos (tuple): Node position matrices, one entry per side of the
                bipartite graph.
            edge_index (LongTensor): The edge indices.

        Returns:
            The aggregated messages concatenated with the target features
            (when present) and target positions, after ``mlp`` if one is set.
        """
        # propagate() inspects the size of every tuple item, so a
        # (tensor, None) pair must be collapsed to the bare tensor first.
        propagate_x = x if x[1] is not None else x[0]
        aggr_out = self.propagate(edge_index, x=propagate_x, pos=pos)

        # Which tuple slot holds the target nodes depends on the flow.
        target_slot = 0 if self.flow == 'target_to_source' else 1
        target_x = x[target_slot]
        target_pos = pos[target_slot]

        parts = [aggr_out]
        if target_x is not None:
            parts.append(target_x)
        parts.append(target_pos)
        aggr_out = torch.cat(parts, dim=1)

        if self.mlp is None:
            return aggr_out
        return self.mlp(aggr_out)

    def message(self, x_j, pos_j, pos_i, edge_index):
        """Weight each edge's source feature by normalised inverse distance.

        x_j: (E, in_channels)
        pos_j: (E, D) positions at the ``_j`` end of each edge
        pos_i: (E, D) positions at the ``_i`` end of each edge
        """
        offset = pos_j - pos_i
        dist = offset.pow(2).sum(dim=1).pow(0.5)
        # Floor the distance so the reciprocal below stays finite.
        floor = torch.Tensor([1e-10]).to(dist.device, dist.dtype)
        dist = torch.max(dist, floor)
        inv_dist = 1.0 / dist  # (E,)

        _, target = edge_index
        num_nodes = maybe_num_nodes(target, None)
        # Sum the raw weights per target node, then read the sum back per edge.
        per_node_sum = scatter_add(inv_dist, target, dim=0, dim_size=num_nodes)
        wsum = per_node_sum[target] + 1e-16  # (E,)
        weight = inv_dist / wsum
        return weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
class PointNet2FPModule(torch.nn.Module):
    """Feature-propagation layer: carry features onto a skip layer's points.

    Args:
        knn_num: Number of neighbours requested from ``knn``.
        mlp: Module handed to ``PointConvFP``.
    """

    def __init__(self, knn_num, mlp):
        super().__init__()
        self.knn_num = knn_num
        self.point_conv = PointConvFP(mlp)

    def forward(self, in_layer_data, skip_layer_data):
        """Propagate ``in_layer_data`` features to ``skip_layer_data`` points.

        Both arguments are ``(x, pos, batch)`` tuples.

        Returns:
            ``(x1, pos1, batch1)`` where ``pos1`` and ``batch1`` are the skip
            layer's own position and batch tensors.
        """
        src_x, src_pos, src_batch = in_layer_data
        dst_x, dst_pos, dst_batch = skip_layer_data

        # The two index vectors returned by knn are stacked in swapped order.
        # Presumably this makes edges run in-layer point -> skip-layer point
        # (torch_geometric `knn` convention); confirm against the installed
        # version.
        first, second = knn(src_pos, dst_pos, self.knn_num, src_batch, dst_batch)
        edges = torch.stack((second, first), dim=0)

        out = self.point_conv((src_x, dst_x), (src_pos, dst_pos), edges)
        return out, dst_pos, dst_batch
def make_mlp(in_channels, mlp_channels, batch_norm=True):
    """Build a ``Sequential`` of Linear (-> BatchNorm1d) -> ReLU stages.

    Args:
        in_channels: Input width of the first Linear layer.
        mlp_channels: Output width of each stage, in order; must be non-empty.
        batch_norm: When true, a ``BatchNorm1d`` follows every Linear layer.

    Returns:
        The stages chained in a ``torch.nn.Sequential``.
    """
    assert len(mlp_channels) >= 1

    def _stage(width_in, width_out):
        stage = [Lin(width_in, width_out)]
        if batch_norm:
            stage.append(BatchNorm1d(width_out))
        stage.append(ReLU())
        return stage

    # Each stage consumes the previous stage's output width.
    widths = [in_channels, *mlp_channels]
    modules = []
    for width_in, width_out in zip(widths, widths[1:]):
        modules.extend(_stage(width_in, width_out))
    return Seq(*modules)
class PointNet2PartSegmentNet(torch.nn.Module):
    """PointNet++ style part-segmentation network.

    Three set-abstraction stages (two local, one global) are followed by
    three feature-propagation stages that carry features back to the input
    points, and a two-layer head that emits per-point class log-probabilities.

    ref:
        - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
        - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py

    Args:
        num_classes: Number of output classes per point.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # SA1
        sa1_sample_ratio = 0.5
        sa1_radius = 0.2
        sa1_max_num_neighbours = 64
        sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)

        # SA2
        sa2_sample_ratio = 0.25
        sa2_radius = 0.4
        sa2_max_num_neighbours = 64
        sa2_mlp = make_mlp(128 + GLOBAL_POINT_FEATURES, [128, 128, 256])
        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)

        # SA3 (global: one output point per cloud)
        sa3_mlp = make_mlp(256 + GLOBAL_POINT_FEATURES, [256, 512, 1024])
        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)

        # NOTE(review): the neighbour count is tied to the point-feature
        # width. This looks like an accidental coupling (the referenced
        # PointNet++ implementations interpolate from 3 neighbours), but
        # changing it alters model outputs -- confirm before touching.
        knn_num = GLOBAL_POINT_FEATURES

        # FP3, reverse of sa3
        fp3_knn_num = 1  # After the global SA module there is only one point per cloud
        fp3_mlp = make_mlp(1024 + 256 + GLOBAL_POINT_FEATURES, [256, 256])
        self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)

        # FP2, reverse of sa2
        fp2_knn_num = knn_num
        fp2_mlp = make_mlp(256 + 128 + GLOBAL_POINT_FEATURES, [256, 128])
        self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)

        # FP1, reverse of sa1 (the input layer carries no features, only pos)
        fp1_knn_num = knn_num
        fp1_mlp = make_mlp(128 + GLOBAL_POINT_FEATURES, [128, 128, 128])
        self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)

        self.fc1 = Lin(128, 128)
        self.dropout1 = Dropout(p=0.5)
        self.fc2 = Lin(128, self.num_classes)

    def forward(self, data):
        """Predict per-point class log-probabilities.

        Args:
            data: A batch of input, either
                - ``torch.Tensor`` of shape (batch_size, C, num_points), the
                  usual dense layout. The C channels become the per-point
                  ``pos`` row, so C has to match the width the MLPs were
                  built for (``GLOBAL_POINT_FEATURES``).
                - a ``torch_geometric`` ``Data``-like object with
                  ``pos`` (total_points, C), ``batch`` (total_points,)
                  holding the cloud index of every point, and optionally
                  ``x`` (total_points, F) point features. Clouds may have
                  different numbers of points.

        Returns:
            Dense input: tensor of shape (batch_size, num_points, num_classes).
            ``Data`` input: tuple ``(log_probs, batch)`` with ``log_probs`` of
            shape (total_points, num_classes) and ``batch`` the matching
            cloud index vector.
        """
        dense_input = isinstance(data, torch.Tensor)

        if dense_input:
            # (batch_size, C, N) -> (batch_size * N, C)
            points = data.transpose(1, 2).contiguous()
            batch_size, N, _ = points.shape
            pos = points.view(batch_size * N, -1)
            # Cloud index of every flattened point: 0,0,...,1,1,...
            batch = (
                torch.arange(batch_size, device=pos.device, dtype=torch.long)
                .view(-1, 1)
                .repeat(1, N)
                .view(-1)
            )
            x = None
        else:
            # Read the optional features without writing to the caller's object.
            x = getattr(data, 'x', None)
            pos, batch = data.pos, data.batch

        data_in = (x, pos, batch)

        sa1_out = self.sa1_module(data_in)
        sa2_out = self.sa2_module(sa1_out)
        sa3_out = self.sa3_module(sa2_out)

        # Each FP stage pairs with the SA stage whose input it restores.
        fp3_out = self.fp3_module(sa3_out, sa2_out)
        fp2_out = self.fp2_module(fp3_out, sa1_out)
        fp1_out = self.fp1_module(fp2_out, data_in)

        fp1_out_x, _, fp1_out_batch = fp1_out
        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
        x = F.log_softmax(x, dim=-1)

        if dense_input:
            return x.view(batch_size, N, self.num_classes)
        return x, fp1_out_batch
if __name__ == '__main__':
    # Smoke test: push random data through the network in both input layouts
    # and print the resulting shapes.

    def make_data_batch():
        """Build a two-cloud torch_geometric batch with unequal cloud sizes."""
        first_size, second_size = 1000, 1024
        pos_parts = [
            torch.rand(first_size, GLOBAL_POINT_FEATURES),
            torch.rand(second_size, GLOBAL_POINT_FEATURES),
        ]
        batch_parts = [
            torch.zeros(first_size, dtype=torch.long),
            torch.ones(second_size, dtype=torch.long),
        ]
        data_batch = Data()
        data_batch.pos = torch.cat(pos_parts, dim=0)
        data_batch.batch = torch.cat(batch_parts)
        return data_batch

    num_classes = 10
    net = PointNet2PartSegmentNet(num_classes)

    print('Test dense input ..')
    # (batch_size, GLOBAL_POINT_FEATURES, num_points)
    data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024))
    print('data1: ', data1.shape)
    out1 = net(data1)
    print('out1: ', out1.shape)

    print('Test torch_geometric.data.Data input ..')
    data2 = make_data_batch()
    print('data2.pos: ', data2.pos.shape)
    print('data2.batch: ', data2.batch.shape)
    out2_x, out2_batch = net(data2)
    print('out2_x: ', out2_x.shape)
    print('out2_batch: ', out2_batch.shape)