Added normals to prediction DataObject
This commit is contained in:
parent
74de208831
commit
ce7ff0ae7c
@ -8,6 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|||||||
from torch_geometric.data.data import Data
|
from torch_geometric.data.data import Data
|
||||||
from torch_scatter import scatter_add, scatter_max
|
from torch_scatter import scatter_add, scatter_max
|
||||||
|
|
||||||
|
GLOBAL_POINT_FEATURES = 3
|
||||||
|
|
||||||
class PointNet2SAModule(torch.nn.Module):
|
class PointNet2SAModule(torch.nn.Module):
|
||||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||||
@ -51,7 +52,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
|||||||
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
||||||
|
|
||||||
batch_size = x1.shape[0]
|
batch_size = x1.shape[0]
|
||||||
pos1 = x1.new_zeros((batch_size, 3)) # set the output point as origin
|
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin
|
||||||
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
||||||
|
|
||||||
return x1, pos1, batch1
|
return x1, pos1, batch1
|
||||||
@ -165,36 +166,36 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
|||||||
sa1_sample_ratio = 0.5
|
sa1_sample_ratio = 0.5
|
||||||
sa1_radius = 0.2
|
sa1_radius = 0.2
|
||||||
sa1_max_num_neighbours = 64
|
sa1_max_num_neighbours = 64
|
||||||
sa1_mlp = make_mlp(3, [64, 64, 128])
|
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
|
||||||
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
||||||
|
|
||||||
# SA2
|
# SA2
|
||||||
sa2_sample_ratio = 0.25
|
sa2_sample_ratio = 0.25
|
||||||
sa2_radius = 0.4
|
sa2_radius = 0.4
|
||||||
sa2_max_num_neighbours = 64
|
sa2_max_num_neighbours = 64
|
||||||
sa2_mlp = make_mlp(128+3, [128, 128, 256])
|
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
|
||||||
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
||||||
|
|
||||||
# SA3
|
# SA3
|
||||||
sa3_mlp = make_mlp(256+3, [256, 512, 1024])
|
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
|
||||||
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
||||||
|
|
||||||
##
|
##
|
||||||
knn_num = 3
|
knn_num = GLOBAL_POINT_FEATURES
|
||||||
|
|
||||||
# FP3, reverse of sa3
|
# FP3, reverse of sa3
|
||||||
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
||||||
fp3_mlp = make_mlp(1024+256+3, [256, 256])
|
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
|
||||||
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
||||||
|
|
||||||
# FP2, reverse of sa2
|
# FP2, reverse of sa2
|
||||||
fp2_knn_num = knn_num
|
fp2_knn_num = knn_num
|
||||||
fp2_mlp = make_mlp(256+128+3, [256, 128])
|
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
|
||||||
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
||||||
|
|
||||||
# FP1, reverse of sa1
|
# FP1, reverse of sa1
|
||||||
fp1_knn_num = knn_num
|
fp1_knn_num = knn_num
|
||||||
fp1_mlp = make_mlp(128+3, [128, 128, 128])
|
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
|
||||||
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
||||||
|
|
||||||
self.fc1 = Lin(128, 128)
|
self.fc1 = Lin(128, 128)
|
||||||
@ -255,7 +256,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
#
|
#
|
||||||
print('Test dense input ..')
|
print('Test dense input ..')
|
||||||
data1 = torch.rand((2, 3, 1024)) # (batch_size, 3, num_points)
|
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points)
|
||||||
print('data1: ', data1.shape)
|
print('data1: ', data1.shape)
|
||||||
|
|
||||||
out1 = net(data1)
|
out1 = net(data1)
|
||||||
@ -271,7 +272,7 @@ if __name__ == '__main__':
|
|||||||
data_batch = Data()
|
data_batch = Data()
|
||||||
|
|
||||||
# data_batch.x = None
|
# data_batch.x = None
|
||||||
data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
|
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
|
||||||
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
||||||
|
|
||||||
return data_batch
|
return data_batch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user