From 95b1503f78476af498042c6fd54ab6b083517578 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Fri, 3 Jul 2020 14:45:08 +0200
Subject: [PATCH] main_pipeline fixed

---
 main_inference.py |  1 +
 main_pipeline.py  | 15 ++++++++-------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/main_inference.py b/main_inference.py
index 45bab96..a3c1317 100644
--- a/main_inference.py
+++ b/main_inference.py
@@ -18,6 +18,7 @@ from ml_lib.utils.model_io import SavedLightningModels
 from datasets.shapenet import ShapeNetPartSegDataset
 from utils.project_config import ThisConfig
+raise BrokenPipeError('There are Imports that need to be fixed first!!!!')
 
 
 def prepare_dataloader(config_obj):
     dataset = ShapeNetPartSegDataset(config_obj.data.root, mode=GlobalVar.data_split.test,
diff --git a/main_pipeline.py b/main_pipeline.py
index 716e3f5..8fdc47f 100644
--- a/main_pipeline.py
+++ b/main_pipeline.py
@@ -19,7 +19,7 @@ from datasets.shapenet import ShapeNetPartSegDataset
 from models import PointNet2
 from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
     write_clusters, cluster2Color, cluster_dbscan
-from utils.project_settings import GlobalVar, DataClass
+from utils.project_settings import dataSplit, DataClass
 
 
 class DisplayMode(DataClass):
@@ -27,6 +27,7 @@ class DisplayMode(DataClass):
     Types = 1,
     Nothing = 2
 
+
 def restore_logger_and_model(log_dir):
     model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
     model = model.restore()
@@ -47,7 +48,7 @@
     batch_to_data = BatchToData()
 
     data = batch_to_data(input_data)
-    y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
+    y = model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
     y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()
 
     if input_pc.shape[1] > 6:
@@ -69,7 +70,7 @@ if __name__ == '__main__':
     loaded_model.eval()
 
     transforms = Compose([NormalizeScale(), ])
-    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
+    test_dataset = ShapeNetPartSegDataset('data', mode=dataSplit.predict, collate_per_segment=False,
                                           refresh=True, transform=transforms, cluster_type=None)
 
     grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
@@ -106,7 +107,7 @@
     total_clusters = []
 
     clusters = cluster_dbscan(final_pc, [0, 1, 2, 3, 4, 5], eps=type_cluster_eps,
-                             min_samples=type_cluster_min_pts)
+                              min_samples=type_cluster_min_pts)
     print("Pre-clustering done. Clusters: ", len(clusters))
 
     for cluster in clusters:
@@ -119,7 +120,7 @@
             total_clusters.append(cluster)
         else:
             sub_clusters = cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9], eps=type_cluster_eps,
-                                         min_samples=type_cluster_min_pts)
+                                          min_samples=type_cluster_min_pts)
             print("Sub clusters: ", len(sub_clusters))
             total_clusters.extend(sub_clusters)
 
@@ -133,14 +134,14 @@
 
     # ========================== Result visualization ==========================
     if display_mode == DisplayMode.Types:
-        pc = ps.register_point_cloud("points_" + str(i), final_pc[:, :3], radius=0.01)
+        pc = ps.register_point_cloud("points_" + str(0), final_pc[:, :3], radius=0.01)
         pc.add_color_quantity("prim types", label2color(final_pc[:, 6].astype(np.int64)), True)
 
     elif display_mode == DisplayMode.Clusters:
 
         for i, result_cluster in enumerate(result_clusters):
             pc = ps.register_point_cloud("points_" + str(i), result_cluster[:, :3], radius=0.01)
-            pc.add_color_quantity("prim types", cluster2Color(result_cluster,i), True)
+            pc.add_color_quantity("prim types", cluster2Color(result_cluster, i), True)
 
     ps.show()