diff --git a/datasets/shapenet.py b/datasets/shapenet.py
index eb7ed55..2b959ae 100644
--- a/datasets/shapenet.py
+++ b/datasets/shapenet.py
@@ -56,14 +56,18 @@ class CustomShapeNet(InMemoryDataset):
     def processed_file_names(self):
         return [f'{self.mode}.pt']
 
-    def __download(self):
-        dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
+    def check_and_resolve_cloud_count(self):
+        if self.raw_dir.exists():
+            dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
 
-        if dir_count:
-            print(f'{dir_count} folders have been found....')
-            return dir_count
-        warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
-        return dir_count
+            if dir_count:
+                print(f'{dir_count} folders have been found....')
+                return dir_count
+            else:
+                warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
+                return dir_count
+        warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
+        return -1
 
     @property
     def num_classes(self):
@@ -87,6 +91,10 @@ class CustomShapeNet(InMemoryDataset):
                 print('Dataset Loaded')
                 break
             except FileNotFoundError:
+                status = self.check_and_resolve_cloud_count()
+                if status in [0, -1]:
+                    print(f'No dataset was loaded, status: {status}')
+                    break
                 self.process()
                 continue
         return data, slices
diff --git a/main_pipeline.py b/main_pipeline.py
index 03ef08a..819d58b 100644
--- a/main_pipeline.py
+++ b/main_pipeline.py
@@ -42,9 +42,8 @@ def restore_logger_and_model(log_dir):
 
 
 def predict_prim_type(input_pc, model):
-    input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)),
-                      pos=torch.tensor(input_pc[:, 0:3]),
-                      y=np.zeros(input_pc.shape[0])
+    input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
+                      pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
                       )
 
     batch_to_data = BatchToData()
@@ -71,7 +70,7 @@ if __name__ == '__main__':
 
     input_pc = normalize_pointcloud(input_pc)
 
-    grid_clusters = cluster_cubes(input_pc, [1,1,1], 2048)
+    grid_clusters = cluster_cubes(input_pc, [1,1,1], 1024)
 
     ps.init()
 
diff --git a/utils/pointcloud.py b/utils/pointcloud.py
index d1c272c..fbceedf 100644
--- a/utils/pointcloud.py
+++ b/utils/pointcloud.py
@@ -1,6 +1,8 @@
 import numpy as np
 from sklearn.cluster import DBSCAN
 
+import open3d as o3d
+
 from pyod.models.knn import KNN
 from pyod.models.sod import SOD
 from pyod.models.abod import ABOD