"""Predict primitive types for a point cloud and cluster it.

Pipeline: load a trained PointNet2 checkpoint, predict a primitive type per point
for each grid cluster of the first "predict"-split sample, run a two-level DBSCAN
on the merged result, write the clusters to ``clusters.txt`` and show them with
polyscope.
"""
from pathlib import Path

import numpy as np
import polyscope as ps
import torch

# Transforms
from torch_geometric.transforms import Compose, NormalizeScale

from ml_lib.point_toolset.point_io import BatchToData
from ml_lib.utils.model_io import SavedLightningModels

# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from models import PointNet2
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
    write_clusters, cluster2Color, cluster_dbscan
from utils.project_settings import GlobalVar, DataClass


class DisplayMode(DataClass):
    """What the polyscope viewer shows once clustering is finished."""
    Clusters = 0  # one point cloud per final cluster, coloured via cluster2Color
    Types = 1     # the merged point cloud, coloured by predicted primitive type
    Nothing = 2   # register no point clouds (the viewer still opens)


# Primitive-type ids collapsed onto a single id before one-hot encoding.
# Ids 0 (Sphere) and 1 (Cylinder) are kept as-is; ids not listed here are untouched.
_COLLAPSED_PRIM_TYPES = (3.0, 4.0, 6.0)  # Box, Polytope, Plane
_COLLAPSED_PRIM_TYPE_ID = 2.0


def restore_logger_and_model(log_dir):
    """Load a PointNet2 checkpoint from *log_dir* and move it to GPU if available, else CPU.

    ``n=-1`` is forwarded to ``SavedLightningModels.load_checkpoint``; presumably it
    selects the last saved checkpoint — TODO confirm against ml_lib.
    """
    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
    model = model.restore()
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


def predict_prim_type(input_pc, model):
    """Run *model* on *input_pc* and append the predicted primitive type as a new column.

    :param input_pc: point array; columns 0-2 are fed as ``pos`` and columns 3-5 as
        ``norm`` (presumably xyz and normals — TODO confirm). Columns past the sixth
        are dropped from the result.
    :param model: network whose output exposes a ``main_out`` score tensor; the
        predicted class is its argmax over the last dimension.
    :return: the first (up to) six input columns plus the predicted class id as the
        last column.
    """
    input_data = dict(
        # np.float64 matches the removed ``np.float`` alias the original used.
        norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float64)).unsqueeze(0),
        pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
    )
    data = BatchToData()(input_data)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Inference only: skip autograd bookkeeping to keep memory down on large clusters.
    with torch.no_grad():
        y = model(data.to(device=device))
    y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()
    return np.concatenate((input_pc[:, :6], y_primary.reshape(-1, 1)), axis=-1)


def _remap_prim_types(pc_with_prim_type, type_col=6):
    """Collapse Box/Polytope/Plane type ids onto a single id, in place; returns the array."""
    types = pc_with_prim_type[:, type_col]  # basic slice -> view, so writes hit the array
    types[np.isin(types, _COLLAPSED_PRIM_TYPES)] = _COLLAPSED_PRIM_TYPE_ID
    return pc_with_prim_type


def _two_level_dbscan(point_cloud, eps, min_samples):
    """Cluster on columns 0-5; re-cluster mixed-type clusters on columns 0-2 and 7-9.

    Clusters whose column 6 (predicted type) holds a single value are kept as-is.
    Columns 7-9 are presumably the one-hot type encoding — TODO confirm against
    ``append_onehotencoded_type``. Only clusters with more than *min_samples* points
    are returned.
    """
    clusters = cluster_dbscan(point_cloud, [0, 1, 2, 3, 4, 5], eps=eps, min_samples=min_samples)
    print("Pre-clustering done. Clusters: ", len(clusters))

    total_clusters = []
    for cluster in clusters:
        print("2nd level clustering ..")
        if len(np.unique(cluster[:, 6], axis=0)) == 1:
            print("No need for 2nd level clustering since there is only a single primitive type in the cluster.")
            total_clusters.append(cluster)
        else:
            sub_clusters = cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9], eps=eps, min_samples=min_samples)
            print("Sub clusters: ", len(sub_clusters))
            total_clusters.extend(sub_clusters)

    return [c for c in total_clusters if c.shape[0] > min_samples]


def _show(display_mode, final_pc, result_clusters):
    """Register the requested point clouds with polyscope and open the viewer."""
    if display_mode == DisplayMode.Types:
        pc = ps.register_point_cloud("points", final_pc[:, :3], radius=0.01)
        pc.add_color_quantity("prim types", label2color(final_pc[:, 6].astype(np.int64)), True)
    elif display_mode == DisplayMode.Clusters:
        for i, result_cluster in enumerate(result_clusters):
            pc = ps.register_point_cloud("points_" + str(i), result_cluster[:, :3], radius=0.01)
            pc.add_color_quantity("prim types", cluster2Color(result_cluster, i), True)
    ps.show()


def main():
    """Run prediction, clustering, cluster export and visualisation on the first sample."""
    display_mode = DisplayMode.Types
    grid_cluster_max_pts = 32 * 1024
    grid_cluster_dims = [1, 1, 1]
    type_cluster_eps = 0.1
    type_cluster_min_pts = 100

    model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'
    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()

    transforms = Compose([NormalizeScale(), ])
    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                          refresh=True, transform=transforms, cluster_type=None)

    grid_clusters = cluster_cubes(test_dataset[0], grid_cluster_dims, max_points_per_cluster=grid_cluster_max_pts)

    ps.init()

    # ========================== Grid Clustering ==========================
    grid_cluster_pcs = []
    for grid_cluster_pc in grid_clusters:
        print("Cluster pointcloud size: {}".format(grid_cluster_pc.shape[0]))
        pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)
        _remap_prim_types(pc_with_prim_type)
        grid_cluster_pcs.append(append_onehotencoded_type(pc_with_prim_type))

    # Merge grid cluster pcs together
    final_pc = np.concatenate(grid_cluster_pcs)

    # ========================== DBSCAN Clustering ==========================
    print("DB Scan on point cloud " + str(final_pc.shape))
    result_clusters = _two_level_dbscan(final_pc, type_cluster_eps, type_cluster_min_pts)

    for cluster in result_clusters:
        print("Cluster: ", cluster.shape[0])

    write_clusters('clusters.txt', result_clusters, 6)

    # ========================== Result visualization ==========================
    _show(display_mode, final_pc, result_clusters)

    print('Done')


if __name__ == '__main__':
    main()