From b7d127e8402988949a61d618197f4b107c10b538 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Thu, 1 Aug 2019 18:16:17 +0200 Subject: [PATCH] File based header detection, collate_per_PC training. --- dataset/shapenet.py | 5 +++-- vis/show_seg_res.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dataset/shapenet.py b/dataset/shapenet.py index 439e0e4..129bd84 100644 --- a/dataset/shapenet.py +++ b/dataset/shapenet.py @@ -287,10 +287,11 @@ class PredictNetPartSegDataset(Dataset): Resample raw point cloud to fixed number of points. Map raw label from range [1, N] to [0, N-1]. """ - def __init__(self, root_dir, transform=None, npoints=2048, headers=True): + def __init__(self, root_dir, train=False, transform=None, npoints=2048, headers=True, collate_per_segment=False): super(PredictNetPartSegDataset, self).__init__() self.npoints = npoints - self.dataset = ShapeNetPartSegDataset(root=root_dir, train=False, transform=transform, headers=headers) + self.dataset = PredictionShapeNet(root=root_dir, train=train, transform=transform, + headers=headers, collate_per_segment=collate_per_segment) def __getitem__(self, index): data = self.dataset[index] diff --git a/vis/show_seg_res.py b/vis/show_seg_res.py index 90e20e7..6de9a6a 100644 --- a/vis/show_seg_res.py +++ b/vis/show_seg_res.py @@ -5,7 +5,7 @@ import sys import os sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory -from dataset.shapenet import ShapeNetPartSegDataset +from dataset.shapenet import PredictNetPartSegDataset, ShapeNetPartSegDataset from model.pointnet2_part_seg import PointNet2PartSegmentNet import torch_geometric.transforms as GT import torch @@ -28,7 +28,7 @@ if __name__ == '__main__': print('Construct dataset ..') test_transform = GT.Compose([GT.NormalizeScale(),]) - test_dataset = ShapeNetPartSegDataset( + test_dataset = PredictNetPartSegDataset( root_dir=opt.dataset, collate_per_segment=False, train=False, @@ -128,12 +128,12 @@ if __name__ == '__main__': print('View gt labels ..') view_points_labels(points, gt_labels) - if True: + if False: print('View diff labels ..') print(diff_labels) view_points_labels(points, diff_labels) - if True: + if False: print('View pred labels ..') print(pred_labels) view_points_labels(points, pred_labels)