diff --git a/datasets/shapenet.py b/datasets/shapenet.py index 2c5f858..de1754c 100644 --- a/datasets/shapenet.py +++ b/datasets/shapenet.py @@ -94,6 +94,30 @@ class CustomShapeNet(InMemoryDataset): def num_classes(self): return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2) + @property + def class_map_all(self): + return {0: 0, + 1: 1, + 2: None, + 3: 2, + 4: 3, + 5: None, + 6: 4, + 7: None + } + + @property + def class_map_poly_as_plane(self): + return {0: 0, + 1: 1, + 2: None, + 3: 2, + 4: 2, + 5: None, + 6: 2, + 7: None + } + def _load_dataset(self): data, slices = None, None filepath = self.processed_paths[0] @@ -154,6 +178,7 @@ class CustomShapeNet(InMemoryDataset): datasets = defaultdict(list) path_to_clouds = self.raw_dir / self.mode found_clouds = list(path_to_clouds.glob('*.xyz')) + class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane if len(found_clouds): for pointcloud in tqdm(found_clouds): if self.cluster_type not in pointcloud.name: @@ -171,7 +196,7 @@ class CustomShapeNet(InMemoryDataset): raise ValueError('Check the Input!!!!!!') # Expand the values from the csv by fake labels if non are provided. vals = vals + [0] * (8 - len(vals)) - + vals[-2] = float(class_map[int(vals[-2])]) src[vals[-1]].append(vals) # Switch from un-pickable Defaultdict to Standard Dict @@ -193,10 +218,6 @@ class CustomShapeNet(InMemoryDataset): if all([x == 0 for x in tensor]): continue tensor = tensor.unsqueeze(0) - if self.poly_as_plane: - tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0 - tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0 - tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0 src[key] = tensor for key, values in src.items(): diff --git a/main.py b/main.py index 53f78f1..1b157c8 100644 --- a/main.py +++ b/main.py @@ -54,7 +54,7 @@ def run_lightning_loop(config_obj): show_progress_bar=True, weights_save_path=logger.log_dir, gpus=[0] if torch.cuda.is_available() else None, - check_val_every_n_epoch=10, + check_val_every_n_epoch=2, # num_sanity_val_steps=config_obj.train.num_sanity_val_steps, # row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting # log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting diff --git a/models/point_net_2.py b/models/point_net_2.py index af9273a..f1f97a1 100644 --- a/models/point_net_2.py +++ b/models/point_net_2.py @@ -25,7 +25,7 @@ class PointNet2(BaseValMixin, # Dataset # ============================================================================= # rot_max_angle = 15 - trans_max_distance = 0.02 + trans_max_distance = 0.01 transforms = Compose( [ RandomFlip(0, p=0.8), @@ -54,7 +54,6 @@ class PointNet2(BaseValMixin, self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2) # Modules - self.point_net_core = () self.lin3 = torch.nn.Linear(128, self.n_classes) # Utility