main_pipeline fixed
This commit is contained in:
parent
5353220890
commit
95b1503f78
@ -18,6 +18,7 @@ from ml_lib.utils.model_io import SavedLightningModels
|
|||||||
from datasets.shapenet import ShapeNetPartSegDataset
|
from datasets.shapenet import ShapeNetPartSegDataset
|
||||||
from utils.project_config import ThisConfig
|
from utils.project_config import ThisConfig
|
||||||
|
|
||||||
|
raise BrokenPipeError('There are Imports that need to be fixed first!!!!')
|
||||||
|
|
||||||
def prepare_dataloader(config_obj):
|
def prepare_dataloader(config_obj):
|
||||||
dataset = ShapeNetPartSegDataset(config_obj.data.root, mode=GlobalVar.data_split.test,
|
dataset = ShapeNetPartSegDataset(config_obj.data.root, mode=GlobalVar.data_split.test,
|
||||||
|
@ -19,7 +19,7 @@ from datasets.shapenet import ShapeNetPartSegDataset
|
|||||||
from models import PointNet2
|
from models import PointNet2
|
||||||
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
|
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
|
||||||
write_clusters, cluster2Color, cluster_dbscan
|
write_clusters, cluster2Color, cluster_dbscan
|
||||||
from utils.project_settings import GlobalVar, DataClass
|
from utils.project_settings import dataSplit, DataClass
|
||||||
|
|
||||||
|
|
||||||
class DisplayMode(DataClass):
|
class DisplayMode(DataClass):
|
||||||
@ -27,6 +27,7 @@ class DisplayMode(DataClass):
|
|||||||
Types = 1,
|
Types = 1,
|
||||||
Nothing = 2
|
Nothing = 2
|
||||||
|
|
||||||
|
|
||||||
def restore_logger_and_model(log_dir):
|
def restore_logger_and_model(log_dir):
|
||||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||||
model = model.restore()
|
model = model.restore()
|
||||||
@ -47,7 +48,7 @@ def predict_prim_type(input_pc, model):
|
|||||||
batch_to_data = BatchToData()
|
batch_to_data = BatchToData()
|
||||||
|
|
||||||
data = batch_to_data(input_data)
|
data = batch_to_data(input_data)
|
||||||
y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
|
y = model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
|
||||||
y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()
|
y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()
|
||||||
|
|
||||||
if input_pc.shape[1] > 6:
|
if input_pc.shape[1] > 6:
|
||||||
@ -69,7 +70,7 @@ if __name__ == '__main__':
|
|||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
|
|
||||||
transforms = Compose([NormalizeScale(), ])
|
transforms = Compose([NormalizeScale(), ])
|
||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
test_dataset = ShapeNetPartSegDataset('data', mode=dataSplit.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms, cluster_type=None)
|
refresh=True, transform=transforms, cluster_type=None)
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||||
@ -106,7 +107,7 @@ if __name__ == '__main__':
|
|||||||
total_clusters = []
|
total_clusters = []
|
||||||
|
|
||||||
clusters = cluster_dbscan(final_pc, [0, 1, 2, 3, 4, 5], eps=type_cluster_eps,
|
clusters = cluster_dbscan(final_pc, [0, 1, 2, 3, 4, 5], eps=type_cluster_eps,
|
||||||
min_samples=type_cluster_min_pts)
|
min_samples=type_cluster_min_pts)
|
||||||
print("Pre-clustering done. Clusters: ", len(clusters))
|
print("Pre-clustering done. Clusters: ", len(clusters))
|
||||||
|
|
||||||
for cluster in clusters:
|
for cluster in clusters:
|
||||||
@ -119,7 +120,7 @@ if __name__ == '__main__':
|
|||||||
total_clusters.append(cluster)
|
total_clusters.append(cluster)
|
||||||
else:
|
else:
|
||||||
sub_clusters = cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9], eps=type_cluster_eps,
|
sub_clusters = cluster_dbscan(cluster, [0, 1, 2, 7, 8, 9], eps=type_cluster_eps,
|
||||||
min_samples=type_cluster_min_pts)
|
min_samples=type_cluster_min_pts)
|
||||||
print("Sub clusters: ", len(sub_clusters))
|
print("Sub clusters: ", len(sub_clusters))
|
||||||
total_clusters.extend(sub_clusters)
|
total_clusters.extend(sub_clusters)
|
||||||
|
|
||||||
@ -133,14 +134,14 @@ if __name__ == '__main__':
|
|||||||
# ========================== Result visualization ==========================
|
# ========================== Result visualization ==========================
|
||||||
if display_mode == DisplayMode.Types:
|
if display_mode == DisplayMode.Types:
|
||||||
|
|
||||||
pc = ps.register_point_cloud("points_" + str(i), final_pc[:, :3], radius=0.01)
|
pc = ps.register_point_cloud("points_" + str(0), final_pc[:, :3], radius=0.01)
|
||||||
pc.add_color_quantity("prim types", label2color(final_pc[:, 6].astype(np.int64)), True)
|
pc.add_color_quantity("prim types", label2color(final_pc[:, 6].astype(np.int64)), True)
|
||||||
|
|
||||||
elif display_mode == DisplayMode.Clusters:
|
elif display_mode == DisplayMode.Clusters:
|
||||||
|
|
||||||
for i, result_cluster in enumerate(result_clusters):
|
for i, result_cluster in enumerate(result_clusters):
|
||||||
pc = ps.register_point_cloud("points_" + str(i), result_cluster[:, :3], radius=0.01)
|
pc = ps.register_point_cloud("points_" + str(i), result_cluster[:, :3], radius=0.01)
|
||||||
pc.add_color_quantity("prim types", cluster2Color(result_cluster,i), True)
|
pc.add_color_quantity("prim types", cluster2Color(result_cluster, i), True)
|
||||||
|
|
||||||
ps.show()
|
ps.show()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user