Added normals to prediction DataObject

This commit is contained in:
Si11ium 2019-08-07 18:26:21 +02:00
parent 74de208831
commit ce7ff0ae7c

View File

@ -8,6 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max
GLOBAL_POINT_FEATURES = 3
class PointNet2SAModule(torch.nn.Module):
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
@ -51,7 +52,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
batch_size = x1.shape[0]
pos1 = x1.new_zeros((batch_size, 3)) # set the output point as origin
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
return x1, pos1, batch1
@ -165,36 +166,36 @@ class PointNet2PartSegmentNet(torch.nn.Module):
sa1_sample_ratio = 0.5
sa1_radius = 0.2
sa1_max_num_neighbours = 64
sa1_mlp = make_mlp(3, [64, 64, 128])
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
# SA2
sa2_sample_ratio = 0.25
sa2_radius = 0.4
sa2_max_num_neighbours = 64
sa2_mlp = make_mlp(128+3, [128, 128, 256])
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
# SA3
sa3_mlp = make_mlp(256+3, [256, 512, 1024])
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
##
knn_num = 3
knn_num = GLOBAL_POINT_FEATURES
# FP3, reverse of sa3
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
fp3_mlp = make_mlp(1024+256+3, [256, 256])
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
# FP2, reverse of sa2
fp2_knn_num = knn_num
fp2_mlp = make_mlp(256+128+3, [256, 128])
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
# FP1, reverse of sa1
fp1_knn_num = knn_num
fp1_mlp = make_mlp(128+3, [128, 128, 128])
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
self.fc1 = Lin(128, 128)
@ -255,7 +256,7 @@ if __name__ == '__main__':
#
print('Test dense input ..')
data1 = torch.rand((2, 3, 1024)) # (batch_size, 3, num_points)
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points)
print('data1: ', data1.shape)
out1 = net(data1)
@ -271,7 +272,7 @@ if __name__ == '__main__':
data_batch = Data()
# data_batch.x = None
data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
return data_batch