Added normals to prediction DataObject
This commit is contained in:
parent
74de208831
commit
ce7ff0ae7c
@ -8,6 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
||||
from torch_geometric.data.data import Data
|
||||
from torch_scatter import scatter_add, scatter_max
|
||||
|
||||
GLOBAL_POINT_FEATURES = 3
|
||||
|
||||
class PointNet2SAModule(torch.nn.Module):
|
||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||
@ -51,7 +52,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
||||
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
||||
|
||||
batch_size = x1.shape[0]
|
||||
pos1 = x1.new_zeros((batch_size, 3)) # set the output point as origin
|
||||
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin
|
||||
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
||||
|
||||
return x1, pos1, batch1
|
||||
@ -165,36 +166,36 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
||||
sa1_sample_ratio = 0.5
|
||||
sa1_radius = 0.2
|
||||
sa1_max_num_neighbours = 64
|
||||
sa1_mlp = make_mlp(3, [64, 64, 128])
|
||||
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
|
||||
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
||||
|
||||
# SA2
|
||||
sa2_sample_ratio = 0.25
|
||||
sa2_radius = 0.4
|
||||
sa2_max_num_neighbours = 64
|
||||
sa2_mlp = make_mlp(128+3, [128, 128, 256])
|
||||
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
|
||||
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
||||
|
||||
# SA3
|
||||
sa3_mlp = make_mlp(256+3, [256, 512, 1024])
|
||||
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
|
||||
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
||||
|
||||
##
|
||||
knn_num = 3
|
||||
knn_num = GLOBAL_POINT_FEATURES
|
||||
|
||||
# FP3, reverse of sa3
|
||||
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
||||
fp3_mlp = make_mlp(1024+256+3, [256, 256])
|
||||
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
|
||||
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
||||
|
||||
# FP2, reverse of sa2
|
||||
fp2_knn_num = knn_num
|
||||
fp2_mlp = make_mlp(256+128+3, [256, 128])
|
||||
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
|
||||
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
||||
|
||||
# FP1, reverse of sa1
|
||||
fp1_knn_num = knn_num
|
||||
fp1_mlp = make_mlp(128+3, [128, 128, 128])
|
||||
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
|
||||
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
||||
|
||||
self.fc1 = Lin(128, 128)
|
||||
@ -255,7 +256,7 @@ if __name__ == '__main__':
|
||||
|
||||
#
|
||||
print('Test dense input ..')
|
||||
data1 = torch.rand((2, 3, 1024)) # (batch_size, 3, num_points)
|
||||
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points)
|
||||
print('data1: ', data1.shape)
|
||||
|
||||
out1 = net(data1)
|
||||
@ -271,7 +272,7 @@ if __name__ == '__main__':
|
||||
data_batch = Data()
|
||||
|
||||
# data_batch.x = None
|
||||
data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
|
||||
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
|
||||
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
||||
|
||||
return data_batch
|
||||
|
Loading…
x
Reference in New Issue
Block a user