diff --git a/dataset/shapenet.py b/dataset/shapenet.py index cf224fc..c26ad18 100644 --- a/dataset/shapenet.py +++ b/dataset/shapenet.py @@ -148,6 +148,8 @@ class CustomShapeNet(InMemoryDataset): #################################### # This is where you define the keys attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6]) + if not self.with_normals: + attr_dict.update(normals=points[:, 3:6]) #################################### if self.collate_per_element: data = Data(**attr_dict) diff --git a/predict/predict.py b/predict/predict.py index a41f37b..ece4273 100644 --- a/predict/predict.py +++ b/predict/predict.py @@ -28,8 +28,9 @@ def eval_sample(net, sample): # points: (n, 3) points, gt_label = sample['points'], sample['labels'] n = points.shape[0] + f = points.shape[1] - points = points.view(1, n, 3) # make a batch + points = points.view(1, n, f) # make a batch points = points.transpose(1, 2).contiguous() points = points.to(device, dtype) @@ -237,15 +238,16 @@ def draw_sample_data(sample_data, colored_normals = False): def recreate_folder(folder): if os.path.exists(folder) and os.path.isdir(folder): shutil.rmtree(folder) - os.mkdir(folder) + os.makedirs(folder, exist_ok=True) sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory parser = argparse.ArgumentParser() parser.add_argument('--npoints', type=int, default=2048, help='resample points number') -parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_3.pth', help='model path') +parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_1.pth', help='model path') parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result') parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers') +parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals') 
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub') parser.add_argument('--has_variations', type=strtobool, default=False, help='whether a single pointcloud has variations ' @@ -303,6 +305,7 @@ if __name__ == '__main__': test_dataset = ShapeNetPartSegDataset( mode='predict', root_dir='data', + with_normals=opt.with_normals, npoints=opt.npoints, refresh=True, collate_per_segment=opt.collate_per_segment, @@ -318,7 +321,7 @@ if __name__ == '__main__': dtype = torch.float # net = PointNetPartSegmentNet(num_classes) - net = PointNet2PartSegmentNet(num_classes) + net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals) net.load_state_dict(torch.load(opt.model, map_location=device.type)) net = net.to(device, dtype) @@ -332,7 +335,10 @@ if __name__ == '__main__': # Predict pred_label, gt_label = eval_sample(net, sample) - sample_data = np.column_stack((sample["points"].numpy(), sample["normals"].numpy(), pred_label.numpy())) + if opt.with_normals: + sample_data = np.column_stack((sample["points"].numpy(), pred_label.numpy())) + else: + sample_data = np.column_stack((sample["points"].numpy(), sample["normals"].numpy(), pred_label.numpy())) draw_sample_data(sample_data, False)