148 lines
5.1 KiB
Python
148 lines
5.1 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
import polyscope as ps
|
|
import numpy as np
|
|
|
|
# Dataset and Dataloaders
|
|
# =============================================================================
|
|
|
|
# Transforms
|
|
from torch_geometric.transforms import Compose, NormalizeScale
|
|
|
|
from ml_lib.point_toolset.point_io import BatchToData
|
|
from ml_lib.utils.model_io import SavedLightningModels
|
|
|
|
# Datasets
|
|
from datasets.shapenet import ShapeNetPartSegDataset
|
|
from models import PointNet2
|
|
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
|
|
write_clusters, cluster2Color, cluster_dbscan
|
|
from utils.project_settings import GlobalVar, DataClass
|
|
|
|
|
|
class DisplayMode(DataClass):
    """Selects what the visualization stage of the ``__main__`` script renders.

    All members are plain ints. Trailing commas (``Clusters = 0,``) would
    silently turn a member into a one-element tuple, which is why none are used.
    """

    Clusters = 0  # register one point cloud per final DBSCAN cluster
    Types = 1  # register the merged cloud, colored by predicted primitive type
    Nothing = 2  # register no point cloud at all
|
|
|
|
def restore_logger_and_model(log_dir):
    """Restore a PointNet2 model from the checkpoints stored under ``log_dir``.

    The restored model is moved to the GPU when CUDA is available, otherwise to
    the CPU.

    Args:
        log_dir: root path handed to ``SavedLightningModels.load_checkpoint`` as
            ``models_root_path``.

    Returns:
        The restored model, already placed on the selected device.
    """
    # n=-1 presumably selects the last saved checkpoint -- confirm in ml_lib.
    saved = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
    restored = saved.restore()

    place_on_device = restored.cuda if torch.cuda.is_available() else restored.cpu
    place_on_device()
    return restored
|
|
|
|
|
|
def predict_prim_type(input_pc, model):
    """Predict a primitive-type class index for every point of ``input_pc``.

    Args:
        input_pc: 2-D numpy array with one row per point. Columns 0-2 are fed to
            the network as ``pos`` and columns 3-5 as ``norm``.
        model: the network to run. It is called as ``model(data)``, and the
            ``main_out`` attribute of its output is arg-maxed over the last
            dimension to obtain one class index per point.

    Returns:
        numpy array made of (at most) the first six columns of ``input_pc``
        followed by one extra column holding the predicted class index.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    input_data = dict(
        # NOTE(review): np.array([...]) followed by unsqueeze(0) makes ``norm``
        # 4-D (1, 1, N, 3) while ``pos`` is (1, N, 3). Kept unchanged because
        # BatchToData presumably flattens both -- confirm.
        # np.float64 replaces the removed alias np.float (same dtype).
        norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float64)).unsqueeze(0),
        pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
    )
    data = BatchToData()(input_data)

    # Use the model we were given (not a module global). This is pure
    # inference, so no autograd graph is needed.
    with torch.no_grad():
        y = model(data.to(device=device))
    y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()

    # Drop any columns beyond xyz + normal before appending the prediction.
    if input_pc.shape[1] > 6:
        input_pc = input_pc[:, :6]

    return np.concatenate((input_pc, y_primary.reshape(-1, 1)), axis=-1)
|
|
|
|
|
|
if __name__ == '__main__':

    # ========================== Settings ==========================
    display_mode = DisplayMode.Types

    grid_cluster_max_pts = 32 * 1024
    grid_clusters = [1, 1, 1]

    type_cluster_eps = 0.1
    type_cluster_min_pts = 100

    model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'

    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()

    transforms = Compose([NormalizeScale(), ])
    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                          refresh=True, transform=transforms, cluster_type=None)

    # Split the first sample into spatial grid cells of bounded size.
    grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)

    ps.init()

    # ========================== Grid Clustering ==========================
    grid_cluster_pcs = []
    for grid_cluster_pc in grid_clusters:
        print("Cluster pointcloud size: {}".format(grid_cluster_pc.shape[0]))

        pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)

        # Re-map primitive types for 1-hot-encoding: Sphere (0) and Cylinder (1)
        # keep their ids; Box (3), Polytope (4) and Plane (6) all collapse to 2.
        # ``prim_types`` is a view, so the assignment writes into pc_with_prim_type.
        prim_types = pc_with_prim_type[:, 6]
        prim_types[np.isin(prim_types, (3.0, 4.0, 6.0))] = 2.0

        grid_cluster_pcs.append(append_onehotencoded_type(pc_with_prim_type))

    # Merge grid cluster pcs together
    final_pc = np.concatenate(grid_cluster_pcs)

    # ========================== DBSCAN Clustering ==========================
    print("DB Scan on point cloud " + str(final_pc.shape))

    # First pass clusters on columns 0-5 (position + normal).
    clusters = cluster_dbscan(final_pc, [0, 1, 2, 3, 4, 5], eps=type_cluster_eps,
                              min_samples=type_cluster_min_pts)
    print("Pre-clustering done. Clusters: ", len(clusters))

    total_clusters = []
    for cluster in clusters:
        print("2nd level clustering ..")

        prim_types_in_cluster = len(np.unique(cluster[:, 6], axis=0))
        if prim_types_in_cluster == 1:
            print("No need for 2nd level clustering since there is only a single primitive type in the cluster.")
            total_clusters.append(cluster)
        else:
            # Columns 7-9 presumably hold the one-hot type encoding appended
            # by append_onehotencoded_type -- confirm.
            sub_clusters = cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9], eps=type_cluster_eps,
                                          min_samples=type_cluster_min_pts)
            print("Sub clusters: ", len(sub_clusters))
            total_clusters.extend(sub_clusters)

    result_clusters = [c for c in total_clusters if c.shape[0] > type_cluster_min_pts]

    for cluster in result_clusters:
        print("Cluster: ", cluster.shape[0])

    write_clusters('clusters.txt', result_clusters, 6)

    # ========================== Result visualization ==========================
    if display_mode == DisplayMode.Types:
        # A single merged cloud, so use a fixed name instead of the loop index
        # left over from the grid-cluster loop.
        pc = ps.register_point_cloud("points", final_pc[:, :3], radius=0.01)
        pc.add_color_quantity("prim types", label2color(final_pc[:, 6].astype(np.int64)), True)

    elif display_mode == DisplayMode.Clusters:
        for i, result_cluster in enumerate(result_clusters):
            pc = ps.register_point_cloud("points_" + str(i), result_cluster[:, :3], radius=0.01)
            pc.add_color_quantity("prim types", cluster2Color(result_cluster, i), True)

    ps.show()

    print('Done')
|