pipeline for single cluster
This commit is contained in:
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms)
|
refresh=True, transform=transforms)
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
|
grid_clusters = cluster_cubes(test_dataset[1], [1, 1, 1], max_points_per_cluster=32768)
|
||||||
|
|
||||||
ps.init()
|
ps.init()
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ class PointNet2(BaseValMixin,
|
|||||||
# Dataset
|
# Dataset
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# rot_max_angle = 15
|
# rot_max_angle = 15
|
||||||
trans_max_distance = 0.01
|
trans_max_distance = 0.02
|
||||||
transforms = Compose(
|
transforms = Compose(
|
||||||
[
|
[
|
||||||
RandomFlip(0, p=0.8),
|
RandomFlip(0, p=0.8),
|
||||||
|
@ -103,10 +103,6 @@ def cluster_per_column(pc, column):
|
|||||||
|
|
||||||
|
|
||||||
def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
|
def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
|
||||||
if cluster_dims[0] == 1 and cluster_dims[1] == 1 and cluster_dims[2] == 1:
|
|
||||||
print("no need to cluster.")
|
|
||||||
return [farthest_point_sampling(data, max_points_per_cluster)]
|
|
||||||
|
|
||||||
if isinstance(data, Data):
|
if isinstance(data, Data):
|
||||||
import torch
|
import torch
|
||||||
candidate_list = list()
|
candidate_list = list()
|
||||||
@ -119,6 +115,10 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_
|
|||||||
|
|
||||||
data = torch.cat(candidate_list, dim=-1).numpy()
|
data = torch.cat(candidate_list, dim=-1).numpy()
|
||||||
|
|
||||||
|
if cluster_dims[0] == 1 and cluster_dims[1] == 1 and cluster_dims[2] == 1:
|
||||||
|
print("no need to cluster.")
|
||||||
|
return [farthest_point_sampling(data, max_points_per_cluster)]
|
||||||
|
|
||||||
max = data[:, :3].max(axis=0)
|
max = data[:, :3].max(axis=0)
|
||||||
max += max * 0.01
|
max += max * 0.01
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user