From ce7ff0ae7ca4026d1208119adbf448ae1bb3b80a Mon Sep 17 00:00:00 2001 From: Si11ium Date: Wed, 7 Aug 2019 18:26:21 +0200 Subject: [PATCH] Added normals to prediction DataObject --- model/pointnet2_part_seg.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/model/pointnet2_part_seg.py b/model/pointnet2_part_seg.py index 64a7cc9..8c66ebf 100644 --- a/model/pointnet2_part_seg.py +++ b/model/pointnet2_part_seg.py @@ -8,6 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.data.data import Data from torch_scatter import scatter_add, scatter_max +GLOBAL_POINT_FEATURES = 3 class PointNet2SAModule(torch.nn.Module): def __init__(self, sample_radio, radius, max_num_neighbors, mlp): @@ -51,7 +52,7 @@ class PointNet2GlobalSAModule(torch.nn.Module): x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1) batch_size = x1.shape[0] - pos1 = x1.new_zeros((batch_size, 3)) # set the output point as origin + pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin batch1 = torch.arange(batch_size).to(batch.device, batch.dtype) return x1, pos1, batch1 @@ -165,36 +166,36 @@ class PointNet2PartSegmentNet(torch.nn.Module): sa1_sample_ratio = 0.5 sa1_radius = 0.2 sa1_max_num_neighbours = 64 - sa1_mlp = make_mlp(3, [64, 64, 128]) + sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128]) self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp) # SA2 sa2_sample_ratio = 0.25 sa2_radius = 0.4 sa2_max_num_neighbours = 64 - sa2_mlp = make_mlp(128+3, [128, 128, 256]) + sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256]) self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp) # SA3 - sa3_mlp = make_mlp(256+3, [256, 512, 1024]) + sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024]) self.sa3_module = PointNet2GlobalSAModule(sa3_mlp) ## - knn_num = 3 + knn_num = GLOBAL_POINT_FEATURES # FP3, reverse of sa3 fp3_knn_num = 1 # After global sa module, there is only one point in point cloud - fp3_mlp = make_mlp(1024+256+3, [256, 256]) + fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256]) self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp) # FP2, reverse of sa2 fp2_knn_num = knn_num - fp2_mlp = make_mlp(256+128+3, [256, 128]) + fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128]) self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp) # FP1, reverse of sa1 fp1_knn_num = knn_num - fp1_mlp = make_mlp(128+3, [128, 128, 128]) + fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128]) self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp) self.fc1 = Lin(128, 128) @@ -255,7 +256,7 @@ if __name__ == '__main__': # print('Test dense input ..') - data1 = torch.rand((2, 3, 1024)) # (batch_size, 3, num_points) + data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points) print('data1: ', data1.shape) out1 = net(data1) @@ -271,7 +272,7 @@ if __name__ == '__main__': data_batch = Data() # data_batch.x = None - data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0) + data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0) data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)]) return data_batch