diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index 74a8b9f..dfb334e 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -141,13 +141,16 @@ class CustomShapeNet(InMemoryDataset):
                             y_all = [-1] * points.shape[0]
 
                     y = torch.as_tensor(y_all, dtype=torch.int)
+                    attr_dict = dict(y=y, pos=points[:, :3])
+                    if self.mode == 'predict':
+                        attr_dict.update(normals=points[:, 3:6])
                     if self.collate_per_element:
-                        data = Data(y=y, pos=points[:, :3])  # , points=points, norm=points[:, 3:])
+                        data = Data(**attr_dict)
                     else:
                         if not data:
                             data = defaultdict(list)
                         # points=points, norm=points[:, 3:]
-                        for key, val in dict(y=y, pos=points[:, :3]).items():
+                        for key, val in attr_dict.items():
                             data[key].append(val)
 
                     data = self._transform_and_filter(data)
@@ -175,9 +178,10 @@ class ShapeNetPartSegDataset(Dataset):
     Resample raw point cloud to fixed number of points.
     Map raw label from range [1, N] to [0, N-1].
     """
-    def __init__(self, root_dir, npoints=1024, **kwargs):
+    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
         super(ShapeNetPartSegDataset, self).__init__()
-        kwargs.update(dict(root_dir=root_dir))
+        self.mode = mode
+        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
         self.npoints = npoints
         self.dataset = CustomShapeNet(**kwargs)
 
@@ -199,6 +203,8 @@ class ShapeNetPartSegDataset(Dataset):
             'points': points,  # torch.Tensor (n, 3)
             'labels': labels  # torch.Tensor (n,)
         }
+        if self.mode == 'predict':
+            sample.update(normals=data.normals)
 
         return sample