diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index 74a8b9f..dfb334e 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -141,13 +141,16 @@ class CustomShapeNet(InMemoryDataset):
             y_all = [-1] * points.shape[0]
             y = torch.as_tensor(y_all, dtype=torch.int)
 
+            attr_dict = dict(y=y, pos=points[:, :3])
+            if self.mode == 'predict':
+                attr_dict.update(normals=points[:, 3:6])
             if self.collate_per_element:
-                data = Data(y=y, pos=points[:, :3])  # , points=points, norm=points[:, 3:])
+                data = Data(**attr_dict)
             else:
                 if not data:
                     data = defaultdict(list)
                 # points=points, norm=points[:, 3:]
-                for key, val in dict(y=y, pos=points[:, :3]).items():
+                for key, val in attr_dict.items():
                     data[key].append(val)
 
             data = self._transform_and_filter(data)
@@ -175,9 +178,10 @@ class ShapeNetPartSegDataset(Dataset):
     Resample raw point cloud to fixed number of points.
     Map raw label from range [1, N] to [0, N-1].
     """
-    def __init__(self, root_dir, npoints=1024, **kwargs):
+    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
         super(ShapeNetPartSegDataset, self).__init__()
-        kwargs.update(dict(root_dir=root_dir))
+        self.mode = mode
+        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
         self.npoints = npoints
         self.dataset = CustomShapeNet(**kwargs)
 
@@ -199,6 +203,8 @@ class ShapeNetPartSegDataset(Dataset):
             'points': points,    # torch.Tensor (n, 3)
             'labels': labels     # torch.Tensor (n,)
         }
+        if self.mode == 'predict':
+            sample.update(normals=data.normals)
 
         return sample
 
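
Usage sketch (not part of the patch): with the new mode flag, a dataset built with mode='predict' carries per-point normals through to each sample dict. The root path below is a placeholder, and the sketch assumes the raw point files store normals in columns 3:6 as the patch reads them.

# Hypothetical usage; 'path/to/raw_pointclouds' is a placeholder and the raw
# files are assumed to provide per-point normals in columns 3:6.
from dataset.shapenet import ShapeNetPartSegDataset

predict_set = ShapeNetPartSegDataset(root_dir='path/to/raw_pointclouds',
                                     npoints=1024, mode='predict')

sample = predict_set[0]
sample['points']   # torch.Tensor (n, 3) xyz coordinates
sample['labels']   # torch.Tensor (n,)
sample['normals']  # only present when mode == 'predict'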