import sys
import os
import shutil
import math
import argparse
import operator
from distutils.util import strtobool

import numpy as np
import torch
import torch_geometric.transforms as GT

# Add the project root directory to the module search path before importing project modules.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import pointcloud as pc


def eval_sample(net, sample):
    '''
    sample: { 'points': tensor(n, f), 'labels': tensor(n,) } with f = 3 (xyz) or 6 (xyz + normals)
    return: (pred_label, gt_label) with labels shape (n,)
    '''
    net.eval()
    with torch.no_grad():
        # points: (n, f)
        points, gt_label = sample['points'], sample['labels']
        n = points.shape[0]
        f = points.shape[1]

        points = points.view(1, n, f)  # make a batch
        points = points.transpose(1, 2).contiguous()  # (1, f, n) as expected by the network
        points = points.to(device, dtype)

        pred = net(points)  # (batch_size, n, num_classes)
        pred_label = pred.max(2)[1]
        pred_label = pred_label.view(-1).cpu()  # (n,)

        assert pred_label.shape == gt_label.shape
        return (pred_label, gt_label)


def append_normal_angles(data):
    '''
    Append the normalized spherical angles (theta, phi) of the normal vectors
    (columns 3:6) as two extra columns.
    '''

    def func(x):
        theta = math.acos(x[2]) / math.pi
        phi = (math.atan2(x[1], x[0]) + math.pi) / (2.0 * math.pi)
        return (theta, phi)

    res = np.array([func(xi) for xi in data[:, 3:6]])

    return np.column_stack((data, res))


def recreate_folder(folder):
    '''Delete the folder if it exists and create it again, empty.'''
    if os.path.exists(folder) and os.path.isdir(folder):
        shutil.rmtree(folder)
    os.mkdir(folder)


parser = argparse.ArgumentParser()
parser.add_argument('--npoints', type=int, default=2048, help='number of points to resample to')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_131.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
parser.add_argument('--headers', type=strtobool, default=True, help='whether the raw files come with headers')
parser.add_argument('--with_normals', type=strtobool, default=True, help='whether the model input includes normals')
parser.add_argument('--collate_per_segment', type=strtobool, default=True,
                    help='whether to collate per segment (sub point cloud) or per whole point cloud')
parser.add_argument('--has_variations', type=strtobool, default=False,
                    help='whether a single point cloud has variations named int(id)_pc.(xyz|dat)')
opt = parser.parse_args()
print(opt)

if __name__ == '__main__':

    # ------------------------------------------------------------------------------------------------------------------
    # Load the point cloud, cluster it and store the clusters as point cloud files again for later prediction.
    # ------------------------------------------------------------------------------------------------------------------

    # Create dataset
    print('Create data set ..')

    dataset_folder = './data/raw/predict/'
    pointcloud_file = './pointclouds/m3.xyz'

    # Load and pre-process the point cloud
    pcloud = pc.read_pointcloud(pointcloud_file)
    pcloud = pc.normalize_pointcloud(pcloud, 1)

    # pc_clusters = pc.hierarchical_clustering(pcloud, selected_indices_0=[0, 1, 2, 3, 4, 5],
    #                                          selected_indices_1=[0, 1, 2, 3, 4, 5], eps=0.7, min_samples=5)

    print("Pre-Processing: Clustering")
    pc_clusters = pc.cluster_cubes(pcloud, [4, 4, 4])
    pc.draw_clusters(pc_clusters)

    recreate_folder(dataset_folder)
    for idx, pcc in enumerate(pc_clusters):
        pcc = pc.farthest_point_sampling(pcc, opt.npoints)
        recreate_folder(dataset_folder + str(idx) + '/')
        pc.write_pointcloud(dataset_folder + str(idx) + '/pc.xyz', pcc)
        # draw_sample_data(pcc, True)

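    # Each cluster is now stored as ./data/raw/predict/<idx>/pc.xyz; these files are presumably
    # what ShapeNetPartSegDataset (mode='predict', root_dir='data') below reads back in.
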
    # ------------------------------------------------------------------------------------------------------------------
    # Load the point cloud clusters and the model.
    # ------------------------------------------------------------------------------------------------------------------

    # Load dataset
    print('Load dataset ..')
    test_transform = GT.Compose([GT.NormalizeScale(), ])

    test_dataset = ShapeNetPartSegDataset(
        mode='predict',
        root_dir='data',
        with_normals=opt.with_normals,
        npoints=opt.npoints,
        refresh=True,
        collate_per_segment=opt.collate_per_segment,
        has_variations=opt.has_variations,
        headers=opt.headers
    )
    num_classes = test_dataset.num_classes()

    # Load model
    print('Construct model ..')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float

    net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
    net.load_state_dict(torch.load(opt.model, map_location=device.type))
    net = net.to(device, dtype)
    net.eval()

    # ------------------------------------------------------------------------------------------------------------------
    # Predict per cluster.
    # ------------------------------------------------------------------------------------------------------------------

    labeled_dataset = None
    result_clusters = []

    # Iterate over all the samples and predict
    for idx, sample in enumerate(test_dataset):

        predicted_label, _ = eval_sample(net, sample)
        if opt.with_normals:
            # when with_normals is set, 'points' already includes the normals (columns 3:6)
            sample_data = np.column_stack((sample["points"].numpy(), predicted_label.numpy()))
        else:
            sample_data = np.column_stack((sample["points"].numpy(), sample["normals"], predicted_label.numpy()))

        result_clusters.append(sample_data)

        if labeled_dataset is None:
            labeled_dataset = sample_data
        else:
            labeled_dataset = np.vstack((labeled_dataset, sample_data))

        print("Prediction done for cluster " + str(idx + 1) + "/" + str(len(test_dataset)))

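    # Assumed row layout of sample_data / labeled_dataset from here on:
    # columns 0-2 = x, y, z; columns 3-5 = nx, ny, nz; column 6 = predicted primitive-type label.
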
    # ------------------------------------------------------------------------------------------------------------------
    # Drop clusters that are too small, and within each cluster drop the points of any primitive type
    # whose share is below a threshold relative to the dominant type.
    # ------------------------------------------------------------------------------------------------------------------

    min_cluster_size = 10
    contamination = 0.01

    filtered_clusters = filter(lambda c: c.shape[0] > min_cluster_size, result_clusters)
    type_filtered_clusters = []
    for c in filtered_clusters:

        prim_types = np.unique(c[:, 6])
        pt_count = {}
        for pt in prim_types:
            pt_count[pt] = len(c[c[:, 6] == pt])

        # Types with fewer points than `contamination` times the dominant type's count are discarded.
        max_pt = max(pt_count.items(), key=operator.itemgetter(1))[0]
        min_size = pt_count[max_pt] * contamination

        valid_types = []
        for pt in prim_types:
            if pt_count[pt] > min_size:
                valid_types.append(pt)

        filtered_c = c[np.isin(c[:, 6], valid_types)]
        type_filtered_clusters.append(filtered_c)

    result_clusters = type_filtered_clusters

    labeled_dataset = np.vstack(result_clusters)

    np.savetxt('labeled_dataset.txt', labeled_dataset)

    # ------------------------------------------------------------------------------------------------------------------
    # Clustering that results in per-primitive-type clusters.
    # ------------------------------------------------------------------------------------------------------------------

    # labeled_dataset = np.loadtxt('labeled_dataset.txt')
    pc.draw_sample_data(labeled_dataset)

    # Try to get rid of outliers.
    labeled_dataset, outliers = pc.split_outliers(labeled_dataset, columns=[0, 1, 2, 3, 4, 5])
    pc.draw_sample_data(outliers, False)

    print("Final clustering ..")

    labeled_dataset = pc.append_onehotencoded_type(labeled_dataset)

    print("Test row: ", labeled_dataset[:1, :])

    total_clusters = []

    # First level: spatial clustering on positions and normals.
    clusters = pc.cluster_dbscan(labeled_dataset, [0, 1, 2, 3, 4, 5], eps=0.1, min_samples=100)
    print("Pre-clustering done. Clusters: ", len(clusters))
    pc.draw_clusters(clusters)

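    # Second level: within each spatial cluster, re-cluster on position plus the one-hot type channels
    # appended by append_onehotencoded_type (the column indices 7-10 used below suggest four type channels
    # following the label column), so that different primitive types end up in separate clusters.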
    for cluster in clusters:
        # cluster = pc.normalize_pointcloud(cluster)

        print("2nd level clustering ..")

        prim_types_in_cluster = len(np.unique(cluster[:, 6], axis=0))
        if prim_types_in_cluster == 1:
            print("No need for 2nd level clustering since there is only a single primitive type in the cluster.")
            total_clusters.append(cluster)
        else:
            sub_clusters = pc.cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9, 10], eps=0.1, min_samples=100)
            print("Sub clusters: ", len(sub_clusters))
            total_clusters.extend(sub_clusters)

    # Keep only sufficiently large clusters.
    result_clusters = list(filter(lambda c: c.shape[0] > 100, total_clusters))

    for cluster in result_clusters:
        print("Cluster: ", cluster.shape[0])
        # pc.draw_sample_data(cluster, False)

    print("Number of clusters: ", len(result_clusters))

    pc.draw_clusters(result_clusters)

    # ------------------------------------------------------------------------------------------------------------------
    # Write the resulting clusters to file.
    # ------------------------------------------------------------------------------------------------------------------
    pc.write_clusters("clusters.txt", result_clusters)
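
# Example invocation (the script file name is an assumption; flags and the default checkpoint path
# come from the argument parser above):
#   python predict.py --npoints 2048 --model ./checkpoint/seg_model_custom_131.pth --with_normals True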
