From a19bd9cafdf0c51dc3a92538f28f4ac2e4844e37 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Fri, 19 Jun 2020 19:00:07 +0200 Subject: [PATCH] dataset modification --- datasets/shapenet.py | 36 +++++++++++++++++++++--------------- main_pipeline.py | 14 +++++++++----- utils/pointcloud.py | 5 +++++ utils/project_settings.py | 1 + 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/datasets/shapenet.py b/datasets/shapenet.py index 2b959ae..b756640 100644 --- a/datasets/shapenet.py +++ b/datasets/shapenet.py @@ -8,7 +8,6 @@ from collections import defaultdict import os from torch.utils.data import Dataset from tqdm import tqdm -import glob import torch from torch_geometric.data import InMemoryDataset @@ -45,7 +44,9 @@ class CustomShapeNet(InMemoryDataset): assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}' # Set the Dataset Parameters - self.collate_per_segment, self.mode, self.refresh = collate_per_segment, mode, refresh + self.collate_per_segment = collate_per_segment + self.mode = mode + self.refresh = refresh self.with_normals = with_normals root_dir = Path(root_dir) super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter) @@ -57,15 +58,15 @@ class CustomShapeNet(InMemoryDataset): return [f'{self.mode}.pt'] def check_and_resolve_cloud_count(self): - if self.raw_dir.exists(): - dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))]) + if (self.raw_dir / self.mode).exists(): + file_count = len([cloud for cloud in (self.raw_dir / self.mode).iterdir() if cloud.is_file()]) - if dir_count: - print(f'{dir_count} folders have been found....') - return dir_count + if file_count: + print(f'{file_count} files have been found....') + return file_count else: warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?")) - return dir_count + return file_count warn(ResourceWarning("The raw data folder does not exist. Was this intentional?")) return -1 @@ -99,7 +100,7 @@ class CustomShapeNet(InMemoryDataset): continue return data, slices - def _transform_and_filter(self, data): + def _pre_transform_and_filter(self, data): # ToDo: ANy filter to apply? Then do it here. if self.pre_filter is not None and not self.pre_filter(data): data = self.pre_filter(data) @@ -133,7 +134,9 @@ class CustomShapeNet(InMemoryDataset): src[key] = torch.tensor(values, dtype=torch.double).squeeze() if not self.collate_per_segment: - src = dict(all=torch.stack([x for x in src.values()])) + src = dict( + all=torch.cat(tuple(src.values())) + ) for key, values in src.items(): try: @@ -157,17 +160,18 @@ class CustomShapeNet(InMemoryDataset): if self.collate_per_segment: data = Data(**attr_dict) else: - if not data: + if data is None: data = defaultdict(list) # points=points, norm=points[:, 3:] for key, val in attr_dict.items(): data[key].append(val) + # data = Data(**data) - data = self._transform_and_filter(data) + # data = self._pre_transform_and_filter(data) if self.collate_per_segment: datasets[self.mode].append(data) if not self.collate_per_segment: - # Todo: What is this? 
+ # This is just to be sure, but should not be needed, since src[all] == all there is in this cloud datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()})) if datasets[self.mode]: @@ -198,8 +202,10 @@ class ShapeNetPartSegDataset(Dataset): # Resample to fixed number of points try: - npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0] - choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True) + npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0] + choice = np.random.choice(data.pos.shape[0], npoints, + replace=False if self.mode == DataSplit.predict else True + ) except ValueError: choice = [] diff --git a/main_pipeline.py b/main_pipeline.py index 819d58b..ec27789 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -57,7 +57,7 @@ def predict_prim_type(input_pc, model): if __name__ == '__main__': - input_pc_path = Path('data') / 'pc' / 'pc.txt' + input_pc_path = Path('data') / 'pc' / 'test.xyz' model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0' # config_filename = 'config.ini' @@ -66,15 +66,19 @@ if __name__ == '__main__': loaded_model = restore_logger_and_model(model_path) loaded_model.eval() - input_pc = read_pointcloud(input_pc_path, ' ', False) + #input_pc = read_pointcloud(input_pc_path, ' ', False) - input_pc = normalize_pointcloud(input_pc) + # input_pc = normalize_pointcloud(input_pc) - grid_clusters = cluster_cubes(input_pc, [1,1,1], 1024) + # TEST DATASET + test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False, + npoints=1024, refresh=True) + + grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024) ps.init() - for i,grid_cluster_pc in enumerate(grid_clusters): + for i, grid_cluster_pc in enumerate(grid_clusters): print("Cluster pointcloud size: {}".format(grid_cluster_pc.shape[0])) diff --git a/utils/pointcloud.py b/utils/pointcloud.py index fbceedf..911f03f 100644 --- a/utils/pointcloud.py +++ b/utils/pointcloud.py @@ -17,6 +17,7 @@ from pyod.models.loci import LOCI from pyod.models.hbos import HBOS from pyod.models.lscp import LSCP from pyod.models.feature_bagging import FeatureBagging +from torch_geometric.data import Data from utils.project_settings import Classes @@ -116,6 +117,10 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_ print("no need to cluster.") return [farthest_point_sampling(data, max_points_per_cluster)] + if isinstance(data, Data): + import torch + data = torch.cat((data.pos, data.norm, data.y.double().unsqueeze(-1)), dim=-1).numpy() + max = data[:, :3].max(axis=0) max += max * 0.01 diff --git a/utils/project_settings.py b/utils/project_settings.py index 6033d8e..b13c54f 100644 --- a/utils/project_settings.py +++ b/utils/project_settings.py @@ -43,6 +43,7 @@ class DataSplit(DataClass): train = 'train' devel = 'devel' test = 'test' + predict = 'predict' class GlobalVar(DataClass):
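
Usage note:

With the new DataSplit.predict entry and the Data-aware cluster_cubes, the
prediction path wired up in main_pipeline.py boils down to the minimal sketch
below. The 'data' root, npoints=1024, refresh=True, collate_per_segment=False
and the [3, 3, 3] grid are taken verbatim from the diff; the import paths are
assumptions inferred from the touched files (datasets/shapenet.py,
utils/pointcloud.py, utils/project_settings.py):

    from datasets.shapenet import ShapeNetPartSegDataset
    from utils.pointcloud import cluster_cubes
    from utils.project_settings import GlobalVar

    # In predict mode the resampling branch keeps every point: npoints falls
    # back to data.pos.shape[0] and np.random.choice runs with replace=False,
    # so the cloud is only shuffled, never subsampled.
    test_dataset = ShapeNetPartSegDataset('data',
                                          mode=GlobalVar.data_split.predict,
                                          collate_per_segment=False,
                                          npoints=1024, refresh=True)

    # test_dataset[0] is a torch_geometric.data.Data object; cluster_cubes now
    # concatenates (pos | norm | y) into an (N, 7) numpy array before slicing
    # the cloud into a 3x3x3 grid of sub-clouds.
    grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3],
                                  max_points_per_cluster=1024)

This replaces the previous flow, which read a single cloud from
data/pc/pc.txt via read_pointcloud, normalized it with normalize_pointcloud,
and clustered the whole cloud with a [1, 1, 1] grid.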