From 7443be4c40f2339427e761e71b55ba703fad3494 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Tue, 6 Aug 2019 14:33:30 +0200 Subject: [PATCH] Labels can now be placed along next to the points within the datasetfile --- dataset/shapenet.py | 9 ++++++--- predict/predict.py | 20 +++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dataset/shapenet.py b/dataset/shapenet.py index ffe829e..74a8b9f 100644 --- a/dataset/shapenet.py +++ b/dataset/shapenet.py @@ -111,7 +111,7 @@ class CustomShapeNet(InMemoryDataset): for element in paths: # This was build to filter all variations that aregreater then 25 - pattern = re.compile('^((6[0-1]|[1-5][0-9])_\w+?\d+?|pc|\d+?_pc)\.(xyz|dat)$') + pattern = re.compile('^((6[0-1]|[1-5][0-9])_\w+?\d+?|\d+?_pc)\.(xyz|dat)$') if pattern.match(os.path.split(element)[-1]): continue else: @@ -134,8 +134,11 @@ class CustomShapeNet(InMemoryDataset): points = points[:, :-1] else: # Get the y - Label - y_raw = next(i for i, v in enumerate(self.categories.keys()) if v.lower() in element.lower()) - y_all = ([y_raw] if self.mode != 'predict' else [-1]) * points.shape[0] + if self.mode != 'predict': + y_raw = next(i for i, v in enumerate(self.categories.keys()) if v.lower() in element.lower()) + y_all = [y_raw] * points.shape[0] + else: + y_all = [-1] * points.shape[0] y = torch.as_tensor(y_all, dtype=torch.int) if self.collate_per_element: diff --git a/predict/predict.py b/predict/predict.py index 9e14e27..c5b8307 100644 --- a/predict/predict.py +++ b/predict/predict.py @@ -6,15 +6,22 @@ from dataset.shapenet import ShapeNetPartSegDataset from model.pointnet2_part_seg import PointNet2PartSegmentNet import torch_geometric.transforms as GT import torch +from distutils.util import strtobool + import numpy as np import argparse ## parser = argparse.ArgumentParser() -parser.add_argument('--dataset', type=str, default='data', help='dataset path') parser.add_argument('--npoints', type=int, default=2048, help='resample points number') -parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_246.pth', help='model path') +parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_3.pth', help='model path') parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result') +parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers') +parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub') +parser.add_argument('--has_variations', type=strtobool, default=False, + help='whether a single pointcloud has variations ' + 'named int(id)_pc.(xyz|dat) look at pointclouds or sub') + opt = parser.parse_args() print(opt) @@ -26,11 +33,14 @@ if __name__ == '__main__': test_dataset = ShapeNetPartSegDataset( mode='predict', - root_dir=opt.dataset, - transform=None, + root_dir='data', npoints=opt.npoints, - refresh=True + refresh=True, + collate_per_segment=opt.collate_per_segment, + has_variations=opt.has_variations, + headers=opt.headers ) + num_classes = test_dataset.num_classes() print('test dataset size: ', len(test_dataset))