From 167ac4991e348123371a68045956523ab9b7cf8b Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Fri, 9 Aug 2019 15:37:26 +0200
Subject: [PATCH] Can now be trained with normals

---
 dataset/shapenet.py |  2 ++
 predict/predict.py  | 16 +++++++++++-----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index cf224fc..c26ad18 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -148,6 +148,8 @@ class CustomShapeNet(InMemoryDataset):
                     ####################################
                     # This is where you define the keys
                     attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
+                    if not self.with_normals:
+                        attr_dict.update(normals=points[:, 3:6])
                     ####################################
                     if self.collate_per_element:
                         data = Data(**attr_dict)
diff --git a/predict/predict.py b/predict/predict.py
index a41f37b..ece4273 100644
--- a/predict/predict.py
+++ b/predict/predict.py
@@ -28,8 +28,9 @@ def eval_sample(net, sample):
         # points: (n, 3)
         points, gt_label = sample['points'], sample['labels']
         n = points.shape[0]
+        f = points.shape[1]
 
-        points = points.view(1, n, 3)  # make a batch
+        points = points.view(1, n, f)  # make a batch
         points = points.transpose(1, 2).contiguous()
         points = points.to(device, dtype)
 
@@ -237,15 +238,16 @@ def draw_sample_data(sample_data, colored_normals = False):
 def recreate_folder(folder):
     if os.path.exists(folder) and os.path.isdir(folder):
         shutil.rmtree(folder)
-    os.mkdir(folder)
+    os.makedirs(folder, exist_ok=True)
 
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')  # add project root directory
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
-parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_3.pth', help='model path')
+parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_1.pth', help='model path')
 parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
 parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
+parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
 parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
 parser.add_argument('--has_variations', type=strtobool, default=False,
                     help='whether a single pointcloud has variations '
@@ -303,6 +305,7 @@ if __name__ == '__main__':
     test_dataset = ShapeNetPartSegDataset(
         mode='predict',
         root_dir='data',
+        with_normals=opt.with_normals,
         npoints=opt.npoints,
         refresh=True,
         collate_per_segment=opt.collate_per_segment,
@@ -318,7 +321,7 @@ if __name__ == '__main__':
     dtype = torch.float
 
     # net = PointNetPartSegmentNet(num_classes)
-    net = PointNet2PartSegmentNet(num_classes)
+    net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
 
     net.load_state_dict(torch.load(opt.model, map_location=device.type))
     net = net.to(device, dtype)
@@ -332,7 +335,10 @@ if __name__ == '__main__':
         # Predict
 
         pred_label, gt_label = eval_sample(net, sample)
-        sample_data = np.column_stack((sample["points"].numpy(), sample["normals"].numpy(), pred_label.numpy()))
+        if opt.with_normals:
+            sample_data = np.column_stack((sample["points"].numpy(), pred_label.numpy()))
+        else:
+            sample_data = np.column_stack((sample["points"].numpy(), sample["normals"], pred_label.numpy()))
 
         draw_sample_data(sample_data, False)