diff --git a/_parameters.py b/_parameters.py
index c8f97d3..d686425 100644
--- a/_parameters.py
+++ b/_parameters.py
@@ -23,11 +23,10 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
 main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
-main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
 
 # Transformations
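The boolean flags above rely on strtobool so that values like "true"/"false" can be passed on the command line. A minimal sketch of that behavior (assuming strtobool is imported from distutils.util, which the flag declarations suggest):

    from argparse import ArgumentParser
    from distutils.util import strtobool

    parser = ArgumentParser()
    # strtobool maps 'y'/'yes'/'t'/'true'/'on'/'1' to 1 and their counterparts to 0,
    # so boolean flags stay toggleable from the shell.
    parser.add_argument("--data_refresh", type=strtobool, default=False, help="")

    args = parser.parse_args(['--data_refresh', 'true'])
    assert args.data_refresh == 1  # returned as int; truthy in boolean checks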
diff --git a/datasets/shapenet.py b/datasets/shapenet.py
index 3f8e0df..a6746f8 100644
--- a/datasets/shapenet.py
+++ b/datasets/shapenet.py
@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 from typing import Union
 from warnings import warn
@@ -96,19 +97,32 @@ class CustomShapeNet(InMemoryDataset):
     def _load_dataset(self):
         data, slices = None, None
         filepath = self.processed_paths[0]
+        config_path = Path(filepath).parent / f'{self.mode}_params.ini'
+        if config_path.exists() and not self.refresh and self.mode != DataSplit().predict:
+            with config_path.open('rb') as f:
+                config = pickle.load(f)
+            # Compare the stored parameter fingerprint with the current settings;
+            # a mismatch forces the processed files to be rebuilt below.
+            if config != self._build_config():
+                print('The given data parameters seem to differ from those used to process the dataset:')
+                self.refresh = True
         if self.refresh:
             try:
                 os.remove(filepath)
+                try:
+                    config_path.unlink()
+                except FileNotFoundError:
+                    pass
                 print('Processed Location "Refreshed" (We deleted the Files)')
             except FileNotFoundError:
-                print('You meant to refresh the allready processed dataset, but there were none...')
+                print('The already processed dataset was marked for refresh, but none was found...')
                 print('continue processing')
                 pass
 
         while True:
             try:
                 data, slices = torch.load(filepath)
-                print('Dataset Loaded')
+                print(f'{self.mode.title()}-Dataset Loaded')
                 break
             except FileNotFoundError:
                 status = self.check_and_resolve_cloud_count()
@@ -117,8 +131,18 @@ class CustomShapeNet(InMemoryDataset):
                     break
                 self.process()
                 continue
+        if self.mode != DataSplit().predict:
+            config = self._build_config()
+            with config_path.open('wb') as f:
+                pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
         return data, slices
 
+    def _build_config(self):
+        conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
+            'classes', 'refresh', 'transform', 'data', 'slices'
+        ]}
+        return conf_dict
+
     def _pre_transform_and_filter(self, data):
         if self.pre_filter is not None and not self.pre_filter(data):
             data = self.pre_filter(data)
@@ -129,76 +153,83 @@ class CustomShapeNet(InMemoryDataset):
     def process(self, delimiter=' '):
         datasets = defaultdict(list)
         path_to_clouds = self.raw_dir / self.mode
-        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
-            if self.cluster_type not in pointcloud.name:
-                continue
-            data = None
-
-            with pointcloud.open('r') as f:
-                src = defaultdict(list)
-                # Iterate over all rows
-                for row in f:
-                    if row != '':
-                        vals = row.rstrip().split(delimiter)[None:None]
-                        vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
-                        src[vals[-1]].append(vals)
-
-            # Switch from un-pickable Defaultdict to Standard Dict
-            src = dict(src)
-
-            # Transform the Dict[List] to Dict[torch.Tensor]
-            for key, values in src.items():
-                src[key] = torch.tensor(values, dtype=torch.double).squeeze()
-
-            # Screw the Sorting and make it a FullCloud rather than a seperated
-            if not self.collate_per_segment:
-                src = dict(
-                    all=torch.cat(tuple(src.values()))
-                )
-
-            # Transform Box and Polytope to Plane if poly_as_plane is set
-            for key, tensor in src.items():
-                if tensor.ndim == 1:
-                    if all([x == 0 for x in tensor]):
-                        continue
-                    tensor = tensor.unsqueeze(0)
-                if self.poly_as_plane:
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
-                src[key] = tensor
-
-            for key, values in src.items():
-                try:
-                    points = values[:, :-2]
-                except IndexError:
-                    continue
-                y = torch.as_tensor(values[:, -2], dtype=torch.long)
-                y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
-                ####################################
-                # This is where you define the keys
-                attr_dict = dict(
-                    y=y,
-                    y_c=y_c,
-                    pos=points[:, :3],
-                    norm=points[:, 3:6]
-                )
-
-                ####################################
-                if self.collate_per_segment:
-                    data = Data(**attr_dict)
-                else:
-                    if data is None:
-                        data = defaultdict(list)
-                    for attr_key, val in attr_dict.items():
-                        data[attr_key].append(val)
-
-                # data = self._pre_transform_and_filter(data)
-                if self.collate_per_segment:
-                    datasets[self.mode].append(data)
-                if not self.collate_per_segment:
-                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+        found_clouds = list(path_to_clouds.glob('*.xyz'))
+        if len(found_clouds):
+            for pointcloud in tqdm(found_clouds):
+                if self.cluster_type not in pointcloud.name:
+                    continue
+                data = None
+
+                with pointcloud.open('r') as f:
+                    src = defaultdict(list)
+                    # Iterate over all rows
+                    for row in f:
+                        if row.strip():
+                            vals = row.rstrip().split(delimiter)
+                            vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
+                            if len(vals) < 6:
+                                raise ValueError('Row has fewer than 6 columns (xyz + normals expected); check the input file!')
+                            # Pad the values from the csv with fake labels if none are provided.
+                            vals = vals + [0] * (8 - len(vals))
+                            src[vals[-1]].append(vals)
+
+                # Switch from un-picklable defaultdict to standard dict
+                src = dict(src)
+
+                # Transform the Dict[List] to Dict[torch.Tensor]
+                for key, values in src.items():
+                    src[key] = torch.tensor(values, dtype=torch.double).squeeze()
+
+                # Skip the sorting and treat it as a full cloud rather than separated segments
+                if not self.collate_per_segment:
+                    src = dict(
+                        all=torch.cat(tuple(src.values()))
+                    )
+
+                # Transform Box and Polytope to Plane if poly_as_plane is set
+                for key, tensor in src.items():
+                    if tensor.ndim == 1:
+                        if all([x == 0 for x in tensor]):
+                            continue
+                        tensor = tensor.unsqueeze(0)
+                    if self.poly_as_plane:
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
+                    src[key] = tensor
+
+                for key, values in src.items():
+                    try:
+                        points = values[:, :-2]
+                    except IndexError:
+                        continue
+                    y = torch.as_tensor(values[:, -2], dtype=torch.long)
+                    y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
+                    ####################################
+                    # This is where you define the keys
+                    attr_dict = dict(
+                        y=y,
+                        y_c=y_c,
+                        pos=points[:, :3],
+                        norm=points[:, 3:6]
+                    )
+
+                    ####################################
+                    if self.collate_per_segment:
+                        data = Data(**attr_dict)
+                    else:
+                        if data is None:
+                            data = defaultdict(list)
+                        for attr_key, val in attr_dict.items():
+                            data[attr_key].append(val)
+
+                    # data = self._pre_transform_and_filter(data)
+                    if self.collate_per_segment:
+                        datasets[self.mode].append(data)
+                    if not self.collate_per_segment:
+                        datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
 
         if datasets[self.mode]:
             os.makedirs(self.processed_dir, exist_ok=True)
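The _load_dataset change above boils down to a parameter fingerprint kept next to the processed file: load the stored fingerprint, compare it to the current settings, force a re-process on mismatch, then write the fingerprint back after a successful load. A standalone sketch of that pattern (simplified names, not the exact class code):

    import pickle
    from pathlib import Path

    def params_changed(processed_file: Path, current_params: dict, mode: str) -> bool:
        # True if a fingerprint exists and differs from the current parameters.
        config_path = processed_file.parent / f'{mode}_params.ini'
        if not config_path.exists():
            return False
        with config_path.open('rb') as f:
            stored = pickle.load(f)
        return stored != current_params

    def store_params(processed_file: Path, current_params: dict, mode: str) -> None:
        # Persist the fingerprint once the dataset has been loaded or (re)processed.
        config_path = processed_file.parent / f'{mode}_params.ini'
        with config_path.open('wb') as f:
            pickle.dump(current_params, f, pickle.HIGHEST_PROTOCOL)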
diff --git a/main_pipeline.py b/main_pipeline.py
index 327d01e..317dffd 100644
--- a/main_pipeline.py
+++ b/main_pipeline.py
@@ -54,9 +54,9 @@ def predict_prim_type(input_pc, model):
 
 if __name__ == '__main__':
 
-    input_pc_path = Path('data') / 'pc' / 'test.xyz'
+    # input_pc_path = Path('data') / 'pc' / 'test.xyz'
 
-    model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
+    model_path = Path('output') / 'PN2' / 'PN_14628b734c5b651b013ad9e36c406934' / 'version_0'
     # config_filename = 'config.ini'
     # config = ThisConfig()
     # config.read_file((Path(model_path) / config_filename).open('r'))
@@ -72,7 +72,7 @@ if __name__ == '__main__':
     test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                           refresh=True, transform=transforms)
 
-    grid_clusters = cluster_cubes(test_dataset[1], [1, 1, 1], max_points_per_cluster=32768)
+    grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
 
     ps.init()
diff --git a/models/_point_net_2.py b/models/_point_net_2.py
index ca4d7bd..3534257 100644
--- a/models/_point_net_2.py
+++ b/models/_point_net_2.py
@@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
 
         self.cord_dims = 6 if self.params.normals_as_cords else 3
         # Modules
-        self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
+        self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
         self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
         self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
         self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
         self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
-        self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
+        self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
 
         self.lin1 = torch.nn.Linear(128, 128)
         self.lin2 = torch.nn.Linear(128, 128)
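The new MLP input sizes follow from how torch_geometric-style SA/FP modules concatenate point coordinates with per-point features: with normals_as_cords the normals are folded into a 6-d coordinate vector and no feature channels remain, otherwise they travel as 3 feature channels next to 3 coordinates. Either way the first set-abstraction level sees 6 input channels, which is why the hardcoded 3 + 3 works, while the last FP level only re-appends the 3 raw feature channels in the non-coordinate case. A sketch of that bookkeeping (illustrative, mirroring the expressions above):

    normals_as_cords = False
    cord_dims = 6 if normals_as_cords else 3  # xyz, plus normals when treated as coordinates
    feat_dims = 0 if normals_as_cords else 3  # normals passed as features otherwise

    sa1_in = feat_dims + cord_dims            # 6 in both configurations -> MLP([3 + 3, ...])
    fp1_in = 128 + feat_dims                  # the skip connection re-appends the raw features
    assert sa1_in == 6
    assert fp1_in == (131 if not normals_as_cords else 128)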
diff --git a/utils/pointcloud.py b/utils/pointcloud.py
index f0b6476..273b6fc 100644
--- a/utils/pointcloud.py
+++ b/utils/pointcloud.py
@@ -75,9 +75,12 @@ def write_pointcloud(file, pc, numCols=6):
 
 
 def farthest_point_sampling(pts, K):
-    if isinstance(pts, Data):
-        pts = pts.pos.numpy()
-    if pts.shape[0] < K:
+    if K > 0:
+        if isinstance(pts, Data):
+            pts = pts.pos.numpy()
+        if pts.shape[0] < K:
+            return pts
+    else:
         return pts
 
     def calc_distances(p0, points):
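The guard above only short-circuits the trivial cases (non-positive K, or fewer points than requested); the actual sampling loop below it is unchanged. For reference, a compact numpy version of the full routine could look like this (a sketch assuming pts is an (N, D) array whose first three columns are xyz, not the project's exact implementation):

    import numpy as np

    def farthest_point_sampling_ref(pts: np.ndarray, K: int) -> np.ndarray:
        if K <= 0 or pts.shape[0] < K:
            return pts  # nothing to subsample
        chosen = [np.random.randint(pts.shape[0])]  # random seed point
        dists = np.linalg.norm(pts[:, :3] - pts[chosen[0], :3], axis=1)
        for _ in range(K - 1):
            idx = int(np.argmax(dists))  # farthest point from the chosen set
            chosen.append(idx)
            # Track, for every point, the distance to its nearest chosen point.
            dists = np.minimum(dists, np.linalg.norm(pts[:, :3] - pts[idx, :3], axis=1))
        return pts[chosen]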