# point_to_primitive/main_pipeline.py
# 2020-06-23 14:37:34 +02:00
#
# 94 lines
# 2.9 KiB
# Python

from pathlib import Path
import torch
import polyscope as ps
import numpy as np
from torch.utils.data import DataLoader
# Dataset and Dataloaders
# =============================================================================
# Transforms
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
from ml_lib.point_toolset.point_io import BatchToData
from ml_lib.utils.model_io import SavedLightningModels
# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from models import PointNet2
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color
from utils.project_settings import GlobalVar
def restore_logger_and_model(log_dir):
    """Load a saved PointNet2 model from ``log_dir`` and place it on a device.

    The checkpoint is looked up through ``SavedLightningModels.load_checkpoint``
    and materialised with ``restore()``. Despite the function name, only the
    model is returned; no logger object is created here.

    Args:
        log_dir: Root path handed to ``load_checkpoint`` as ``models_root_path``.

    Returns:
        The restored model, moved to CUDA when available and to the CPU
        otherwise.
    """
    # NOTE(review): ``n=-5`` is forwarded unchanged; presumably it selects which
    # saved checkpoint to use -- confirm against SavedLightningModels.
    checkpoint = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
    restored = checkpoint.restore()

    if not torch.cuda.is_available():
        restored.cpu()
    else:
        restored.cuda()
    return restored
def predict_prim_type(input_pc, model):
    """Predict a primitive-type label for every point of a point cloud.

    Args:
        input_pc: 2-D NumPy array with one row per point. Columns 0-2 are read
            as positions and columns 3-5 as normals; any further columns are
            ignored and dropped from the result.
        model: Callable network. It is invoked on the ``BatchToData`` output
            and must return an object exposing ``main_out``; the label is the
            argmax over the last dimension of ``main_out``.

    Returns:
        NumPy array holding the first six columns of ``input_pc`` followed by
        one extra column with the predicted label index of each point.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    input_data = dict(
        # ``np.float`` was removed in NumPy 1.24; it was an alias of the builtin
        # float, so ``np.float64`` yields the same dtype.
        # NOTE(review): wrapping in a list plus ``unsqueeze(0)`` gives ``norm``
        # one more leading axis than ``pos``. Kept as-is because what
        # BatchToData expects is not visible here -- confirm.
        norm=torch.tensor(np.array([input_pc[:, 3:6]], dtype=np.float64)).unsqueeze(0),
        pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
    )
    data = BatchToData()(input_data)

    # Use the model passed in by the caller rather than a module-level global,
    # and skip autograd bookkeeping since this is inference only.
    with torch.no_grad():
        y = model(data.to(device=device))
    y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()

    # Slicing is a no-op when there are six columns or fewer, so no width
    # check is needed before trimming the extra columns.
    return np.concatenate((input_pc[:, :6], y_primary.reshape(-1, 1)), axis=-1)
if __name__ == '__main__':
    # Script entry point: restore the trained network, split the first sample
    # of the prediction dataset into grid clusters, label each cluster and
    # display the result in polyscope.
    input_pc_path = Path('data', 'pc', 'test.xyz')
    model_path = Path('output', 'PN2', 'PN_9843bf499399786cfd58fe79fa1b3db8', 'version_0')

    # Disabled: reading the run configuration stored next to the checkpoint.
    # config_filename = 'config.ini'
    # config = ThisConfig()
    # config.read_file((Path(model_path) / config_filename).open('r'))

    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()

    # Disabled: loading a raw point cloud from ``input_pc_path`` instead of the dataset.
    # input_pc = read_pointcloud(input_pc_path, ' ', False)
    # input_pc = normalize_pointcloud(input_pc)

    # Test dataset
    dataset_options = dict(
        mode=GlobalVar.data_split.predict,
        collate_per_segment=True,
        refresh=True,
        transform=Compose([NormalizeScale()]),
    )
    prediction_set = ShapeNetPartSegDataset('data', **dataset_options)

    cluster_list = cluster_cubes(prediction_set[0], [3, 3, 3], max_points_per_cluster=1024)

    ps.init()

    for cluster_idx, cluster_pc in enumerate(cluster_list):
        print("Cluster pointcloud size: {}".format(cluster_pc.shape[0]))

        labelled_pc = predict_prim_type(cluster_pc, loaded_model)
        # labelled_pc = polytopes_to_planes(labelled_pc)
        labelled_pc = append_onehotencoded_type(labelled_pc)

        # Columns 0-2 are passed as point positions; column 6 is cast to
        # int64 and mapped to colours via label2color.
        cloud_name = "points_" + str(cluster_idx)
        point_labels = labelled_pc[:, 6].astype(np.int64)
        cloud = ps.register_point_cloud(cloud_name, labelled_pc[:, :3], radius=0.01)
        cloud.add_color_quantity("prim types", label2color(point_labels), True)

    # NOTE(review): placed after the loop so every cluster is registered before
    # the viewer opens -- confirm this matches the intended nesting.
    ps.show()

    print('Done')