commit 0159363642ebd4373caae9de9b706c0e1557b0f6
Author: Si11ium <steffen.illium@ifi.lmu.de>
Date:   Tue Jul 30 07:44:45 2019 +0200

    initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..87b49df
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,129 @@
+# Created by .ignore support plugin (hsz.mobi)
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+/data/
+/checkpoint/
+/shapenet/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..5c98b42
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..a663f10
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="JavaScriptSettings">
+    <option name="languageLevel" value="ES6" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (torch)" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..57bd2fd
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" filepath="$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/pointnet2-pytorch.iml b/.idea/pointnet2-pytorch.iml
new file mode 100644
index 0000000..7401a39
--- /dev/null
+++ b/.idea/pointnet2-pytorch.iml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/data" />
+      <excludeFolder url="file://$MODULE_DIR$/net" />
+    </content>
+    <orderEntry type="jdk" jdkName="Python 3.7 (torch)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9dd12af
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+# PointNet++ Part Segmentation
+This repo is an implementation of the [PointNet++](https://arxiv.org/abs/1706.02413) part segmentation model, based on [PyTorch](https://pytorch.org) and [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric). It achieves comparable or better performance than [PointCNN](https://arxiv.org/abs/1801.07791) on the ShapeNet dataset.
+
+**The model has been merged into [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) as a point cloud segmentation [example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py), so you can try it there.**
+
+# Performance
+Segmentation results on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
+
+| Method | mcIoU | Airplane | Bag | Cap | Car | Chair | Earphone | Guitar | Knife | Lamp | Laptop | Motorbike | Mug | Pistol | Rocket | Skateboard | Table |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| PointNet++   | 81.9| 82.4| 79.0| 87.7| 77.3 |90.8| 71.8| 91.0| 85.9| 83.7| 95.3| 71.6| 94.1| 81.3| 58.7| 76.4| 82.6| 
+| PointCNN     | 84.6| 84.11| **86.47**| 86.04| **80.83**| 90.62| **79.70**| 92.32| 88.44| 85.31| 96.11| **77.20**| 95.28| 84.21| 64.23| **80.00**| 82.99| 
+| PointNet++(this repo) | **84.68**| **85.42**| 85.92| **88.39**| 79.73| **91.86**| 75.37| **92.95**| **88.56**| **85.72**| **97.00**| 72.94| **96.88**| **84.52**| **64.38**| 79.39| **85.91**|
+
+mcIoU: mean per-class part IoU (pIoU)
+
+# Requirements
+- [PyTorch](https://pytorch.org)
+- [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric)
+- [Open3D](https://github.com/intel-isl/Open3D) (optional, for visualizing segmentation results)
+
+## Quick install of pytorch_geometric and Open3D with pip/Anaconda
+```
+$ pip install --verbose --no-cache-dir torch-scatter
+$ pip install --verbose --no-cache-dir torch-sparse
+$ pip install --verbose --no-cache-dir torch-cluster
+$ pip install --verbose --no-cache-dir torch-spline-conv  # optional
+$ pip install torch-geometric
+```
+
+```
+# optional
+conda install -c open3d-admin open3d
+```
+
+# Usage
+Training
+```
+python main.py
+```
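+
+Training expects raw point clouds under `<dataset>/raw/train` and `<dataset>/raw/test` (one folder per point cloud, each containing `*.dat` files), as read by `dataset/shapenet.py`.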
+
+Show segmentation result
+```
+python vis/show_seg_res.py
+```
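+
+You can also drive the dataset and model classes directly from Python. Below is a minimal sketch (the `root_dir` path and `npoints` value are illustrative, and a trained checkpoint would normally be loaded via `net.load_state_dict` first):
+```
+import torch
+
+from dataset.shapenet import ShapeNetPartSegDataset
+from model.pointnet2_part_seg import PointNet2PartSegmentNet
+
+dataset = ShapeNetPartSegDataset(root_dir='data', train=False, npoints=1024)
+net = PointNet2PartSegmentNet(dataset.num_classes())
+net.eval()
+
+sample = dataset[0]  # {'points': tensor(n, 3), 'labels': tensor(n,)}
+points = sample['points'].view(1, -1, 3).transpose(1, 2).contiguous()  # (1, 3, n)
+with torch.no_grad():
+    pred = net(points)  # (1, n, num_classes) log-probabilities
+pred_label = pred.max(2)[1].view(-1)  # predicted part label per point
+```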
+
+# Sample segmentation result
+![segmentation_result](figs/segmentation_result.png)
+
+
+# Links
+- [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) by fxia22. This repo's training code is heavily borrowed from fxia22's repo.
+- Official [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2) tensorflow implementations
+- [PointNet++ classification example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet%2B%2B.py) of pytorch_geometric library
diff --git a/dataset/shapenet.py b/dataset/shapenet.py
new file mode 100644
index 0000000..c8a694e
--- /dev/null
+++ b/dataset/shapenet.py
@@ -0,0 +1,158 @@
+import glob
+import os
+from collections import defaultdict
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torch_geometric.data import Data, InMemoryDataset
+from tqdm import tqdm
+
+
+class CustomShapeNet(InMemoryDataset):
+
+    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
+
+    def __init__(self, root, train=True, transform=None, pre_filter=None, pre_transform=None, **kwargs):
+        self.train = train  # stored for _load_dataset()
+        super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
+        path = self.processed_paths[0] if train else self.processed_paths[1]
+        self.data, self.slices = torch.load(path)
+        print("Initialized")
+
+    @property
+    def raw_file_names(self):
+        # Maybe add more data like validation sets
+        return ['train', 'test']
+
+    @property
+    def processed_file_names(self):
+        return [f'{x}.pt' for x in self.raw_file_names]
+
+    def download(self):
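+        # There is nothing to download; just verify that raw point clouds are present on disk.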
+        dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
+        print(f'{dir_count} folders have been found.')
+        if dir_count:
+            return dir_count
+        raise IOError("No raw pointclouds have been found.")
+
+    @property
+    def num_classes(self):
+        return len(self.categories)
+
+    def _load_dataset(self):
+        data, slices = None, None
+        while True:
+            try:
+                filepath = os.path.join(self.root, self.processed_dir, f'{"train" if self.train else "test"}.pt')
+                data, slices = torch.load(filepath)
+                print('Dataset Loaded')
+                break
+            except FileNotFoundError:
+                self.process()
+                continue
+        return data, slices
+
+    def process(self, delimiter=' '):
+        # idx = self.categories[self.category]
+        # paths = [osp.join(path, idx) for path in self.raw_paths]
+
+        datasets = defaultdict(list)
+        for idx, setting in enumerate(self.raw_file_names):
+            for pointcloud in tqdm(os.scandir(os.path.join(self.raw_dir, setting))):
+                if not os.path.isdir(pointcloud):
+                    continue
+                for element in glob.glob(os.path.join(pointcloud.path, '*.dat')):
+                    if os.path.split(element)[-1] not in ['pc.dat']:
+                        # Assign training data to the data container.
+                        # Following the original logic:
+                        # y is the label;
+                        # pos is the six-dimensional pose vector (not the points!):
+                        # x, y, z, x_rot, y_rot, z_rot
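+                        # Expected *.dat layout (an assumption, inferred from the parsing below):
+                        #   <node_count> ...              <- header line
+                        #   x y z x_rot y_rot z_rot       <- one delimiter-separated row per node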
+                        y_raw = os.path.splitext(element)[0].split('_')[-2]
+                        with open(element, 'r') as f:
+                            headers = next(f)
+                            # Skip files with no usable nodes (the header count says 0).
+                            if not int(headers.rstrip().split(delimiter)[0]):
+                                continue
+
+                            # Iterate over all remaining rows
+                            src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
+                                    for x in line.rstrip().split(delimiter)] for line in f if line.rstrip()]
+                        points = torch.tensor(src).squeeze()
+                        if not len(points.shape) > 1:
+                            continue
+                        # pos = points[:, :3]
+                        # norm = points[:, 3:]
+                        y_all = [self.categories[y_raw]] * points.shape[0]
+                        y = torch.as_tensor(y_all, dtype=torch.int)
+                        # points = torch.as_tensor(points, dtype=torch.float)
+                        # norm = torch.as_tensor(norm, dtype=torch.float)
+                        data = Data(y=y, pos=points[:, :3])
+                        # , points=points, norm=points[:3], )
+                        # ToDo: Any filter to apply? Then do it here.
+                        if self.pre_filter is not None and not self.pre_filter(data):
+                            raise NotImplementedError('pre_filter is not implemented yet.')
+                        # ToDo: Any transformation to apply? Then do it here.
+                        if self.pre_transform is not None:
+                            raise NotImplementedError('pre_transform is not implemented yet.')
+                        datasets[setting].append(data)
+
+            os.makedirs(self.processed_dir, exist_ok=True)
+            torch.save(self.collate(datasets[setting]), self.processed_paths[idx])
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({len(self)})'
+
+
+class ShapeNetPartSegDataset(Dataset):
+    """
+    Resample raw point cloud to fixed number of points.
+    Map raw label from range [1, N] to [0, N-1].
+    """
+    def __init__(self, root_dir, train=True, transform=None, npoints=1024):
+        super(ShapeNetPartSegDataset, self).__init__()
+        self.npoints = npoints
+        self.dataset = CustomShapeNet(root=root_dir, train=train, transform=transform)
+
+    def __getitem__(self, index):
+        data = self.dataset[index]
+        points, labels = data.pos, data.y
+
+        # Resample to fixed number of points
+        try:
+            choice = np.random.choice(points.shape[0], self.npoints, replace=True)
+        except ValueError:
+            choice = []
+
+        points, labels = points[choice, :], labels[choice]
+
+        labels -= 1 if self.num_classes() in labels else 0   # Map labels from [1, C] to [0, C-1] if they are one-based
+
+        sample = {
+            'points': points,  # torch.Tensor (n, 3)
+            'labels': labels  # torch.Tensor (n,)
+        }
+
+        return sample
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def num_classes(self):
+        return self.dataset.num_classes
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c287cc7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,200 @@
+"""
+Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
+"""
+import os
+import random
+import numpy as np
+import argparse
+import torch
+from torch.utils.data import DataLoader
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch import autograd
+import torch.backends.cudnn as cudnn
+
+from dataset.shapenet import ShapeNetPartSegDataset
+from model.pointnet2_part_seg import PointNet2PartSegmentNet
+import torch_geometric.transforms as GT
+
+import time
+
+# Argument parser
+parser = argparse.ArgumentParser()
+default_data_dir = os.path.join(os.getcwd(), 'data')
+parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
+parser.add_argument('--npoints', type=int, default=50, help='resample points number')
+parser.add_argument('--model', type=str, default='', help='model path to resume from (leave empty to train from scratch)')
+parser.add_argument('--nepoch', type=int, default=10, help='number of epochs to train for')
+parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
+parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
+parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
+parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
+
+opt = parser.parse_args()
+print(opt)
+
+# Random seed
+opt.manual_seed = 123
+print('Random seed: ', opt.manual_seed)
+random.seed(opt.manual_seed)
+np.random.seed(opt.manual_seed)
+torch.manual_seed(opt.manual_seed)
+torch.cuda.manual_seed(opt.manual_seed)
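+# Note: bitwise reproducibility on GPU would additionally require
+# torch.backends.cudnn.deterministic = True, which is not set here.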
+
+if __name__ == '__main__':
+
+    # Dataset and transform
+    print('Construct dataset ..')
+    rot_max_angle = 15
+    trans_max_distance = 0.01
+
+    RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
+                               GT.RandomRotate(rot_max_angle, 1),
+                               GT.RandomRotate(rot_max_angle, 2)])
+    TransTransform = GT.RandomTranslate(trans_max_distance)
+
+    train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
+    test_transform = GT.Compose([GT.NormalizeScale()])
+
+    dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
+    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+
+    test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
+    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+
+    num_classes = dataset.num_classes()
+
+    print('dataset size: ', len(dataset))
+    print('test_dataset size: ', len(test_dataset))
+    print('num_classes: ', num_classes)
+
+    os.makedirs(opt.outf, exist_ok=True)
+
+
+    ## Model, criterion and optimizer
+    print('Construct model ..')
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float
+    print('cudnn.enabled: ', torch.backends.cudnn.enabled)
+
+
+    net = PointNet2PartSegmentNet(num_classes)
+
+    if opt.model != '':
+        net.load_state_dict(torch.load(opt.model))
+    net = net.to(device, dtype)
+
+    criterion = nn.NLLLoss()
+    optimizer = optim.Adam(net.parameters())
+    ## Train
+    print('Training ..')
+    blue = lambda x: '\033[94m' + x + '\033[0m'
+    num_batch = len(dataset) // opt.batch_size
+    test_per_batches = opt.test_per_batches
+
+    print('number of epochs: ', opt.nepoch)
+    print('number of batches per epoch: ', num_batch)
+    print('run test per batches: ', test_per_batches)
+
+    for epoch in range(opt.nepoch):
+        print('Epoch {} of {} total epochs'.format(epoch+1, opt.nepoch))
+
+        net.train()
+
+        for batch_idx, sample in enumerate(dataloader):
+            # points: (batch_size, n, 3)
+            # labels: (batch_size, n)
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
+            points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+            optimizer.zero_grad()
+
+            pred = net(points)  # (batch_size, n, num_classes)
+            pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
+            target = labels.view(-1, 1)[:, 0]
+
+            loss = F.nll_loss(pred, target)
+            loss.backward()
+
+            optimizer.step()
+
+            ##
+            pred_label = pred.detach().max(1)[1]
+            correct = pred_label.eq(target.detach()).cpu().sum()
+            total = pred_label.shape[0]
+
+            print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
+
+            ##
+            if batch_idx % test_per_batches == 0:
+                print('Run a test batch')
+                net.eval()
+
+                with torch.no_grad():
+                    # Draw one test batch; do not shadow the training loop's batch_idx.
+                    sample = next(iter(test_dataloader))
+
+                    points, labels = sample['points'], sample['labels']
+                    points = points.transpose(1, 2).contiguous()
+                    points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+                    pred = net(points)
+                    pred = pred.view(-1, num_classes)
+                    target = labels.view(-1, 1)[:, 0]
+
+                    # Guard against labels the dataset mapping shifted below zero.
+                    target += 1 if -1 in target else 0
+                    loss = F.nll_loss(pred, target)
+
+                    pred_label = pred.detach().max(1)[1]
+                    correct = pred_label.eq(target.detach()).cpu().sum()
+                    total = pred_label.shape[0]
+                    print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
+
+                # Back to training mode
+                net.train()
+
+        torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
+
+
+    ## Benchmark mIoU
+    net.eval()
+    shape_ious = []
+
+    with torch.no_grad():
+        for batch_idx, sample in enumerate(test_dataloader):
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous()
+            points = points.to(device, dtype)
+
+            # start_t = time.time()
+            pred = net(points) # (batch_size, n, num_classes)
+            # print('batch inference forward time used: {} ms'.format(time.time() - start_t))
+
+            pred_label = pred.max(2)[1]
+            pred_label = pred_label.cpu().numpy()
+            target_label = labels.numpy()
+
+            batch_size = target_label.shape[0]
+            for shape_idx in range(batch_size):
+                parts = range(num_classes)  # np.unique(target_label[shape_idx])
+                part_ious = []
+                for part in parts:
+                    I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
+                    U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
+                    iou = 1 if U == 0 else float(I) / U
+                    part_ious.append(iou)
+                shape_ious.append(np.mean(part_ious))
+
+        print(f'mIoU on the custom test set: {np.mean(shape_ious)}')
+
+    print('Done.')
diff --git a/model/pointnet2_part_seg.py b/model/pointnet2_part_seg.py
new file mode 100644
index 0000000..64a7cc9
--- /dev/null
+++ b/model/pointnet2_part_seg.py
@@ -0,0 +1,286 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
+from torch_geometric.nn import PointConv, fps, radius, knn
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.nn.inits import reset
+from torch_geometric.utils.num_nodes import maybe_num_nodes
+from torch_geometric.data.data import Data
+from torch_scatter import scatter_add, scatter_max
+
+
+class PointNet2SAModule(torch.nn.Module):
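+    '''
+    Set abstraction (SA) module: sample centroids via farthest point sampling,
+    group neighbours within a radius around them, and apply a shared PointNet (PointConv).
+    '''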
+    def __init__(self, sample_ratio, radius, max_num_neighbors, mlp):
+        super(PointNet2SAModule, self).__init__()
+        self.sample_ratio = sample_ratio
+        self.radius = radius
+        self.max_num_neighbors = max_num_neighbors
+        self.point_conv = PointConv(mlp)
+
+    def forward(self, data):
+        x, pos, batch = data
+
+        # Sample
+        idx = fps(pos, batch, ratio=self.sample_ratio)
+
+        # Group(Build graph)
+        row, col = radius(pos, pos[idx], self.radius, batch, batch[idx], max_num_neighbors=self.max_num_neighbors)
+        edge_index = torch.stack([col, row], dim=0)
+
+        # Apply pointnet
+        x1 = self.point_conv(x, (pos, pos[idx]), edge_index)
+        pos1, batch1 = pos[idx], batch[idx]
+
+        return x1, pos1, batch1
+
+
+class PointNet2GlobalSAModule(torch.nn.Module):
+    '''
+    One group containing all input points; can be viewed as a simple PointNet module.
+    It also returns a single output point (set to the origin).
+    '''
+    def __init__(self, mlp):
+        super(PointNet2GlobalSAModule, self).__init__()
+        self.mlp = mlp
+
+    def forward(self, data):
+        x, pos, batch = data
+        x = pos if x is None else torch.cat([x, pos], dim=1)
+        x1 = self.mlp(x)
+
+        x1 = scatter_max(x1, batch, dim=0)[0]  # (batch_size, C1)
+
+        batch_size = x1.shape[0]
+        pos1 = x1.new_zeros((batch_size, 3))  # set the output point as origin
+        batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
+
+        return x1, pos1, batch1
+
+
+class PointConvFP(MessagePassing):
+    '''
+    Core layer of the feature propagation module.
+    '''
+    def __init__(self, mlp=None):
+        super(PointConvFP, self).__init__('add', 'source_to_target')
+        self.mlp = mlp
+        self.aggr = 'add'
+        self.flow = 'source_to_target'
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        reset(self.mlp)
+
+    def forward(self, x, pos, edge_index):
+        r"""
+        Args:
+            x (tuple), (tensor, tensor) or (tensor, NoneType)
+            pos (tuple): The node position matrix. Either given as
+                tensor for use in general message passing or as tuple for use
+                in message passing in bipartite graphs.
+            edge_index (LongTensor): The edge indices.
+        """
+        # Do not pass (tensor, None) directly into propagate(), since it will check each item's size() inside.
+        x_tmp = x[0] if x[1] is None else x
+        aggr_out = self.propagate(edge_index, x=x_tmp, pos=pos)
+
+        #
+        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
+        x_target, pos_target = x[i], pos[i]
+
+        add = [pos_target,] if x_target is None else [x_target, pos_target]
+        aggr_out = torch.cat([aggr_out, *add], dim=1)
+
+        if self.mlp is not None: aggr_out = self.mlp(aggr_out)
+
+        return aggr_out
+
+    def message(self, x_j, pos_j, pos_i, edge_index):
+        '''
+        x_j: (E, in_channels)
+        pos_j: (E, 3)
+        pos_i: (E, 3)
+        '''
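+        # Inverse-distance-weighted interpolation (PointNet++ feature propagation):
+        # target point i receives sum_j w_ij * x_j with w_ij = (1/d_ij) / sum_k (1/d_ik).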
+        dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
+        dist = torch.max(dist, torch.Tensor([1e-10]).to(dist.device, dist.dtype))
+        weight = 1.0 / dist  # (E,)
+
+        row, col = edge_index
+        index = col
+        num_nodes = maybe_num_nodes(index, None)
+        wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16  # (E,)
+        weight /= wsum
+
+        return weight.view(-1, 1) * x_j
+
+    def update(self, aggr_out):
+        return aggr_out
+
+
+class PointNet2FPModule(torch.nn.Module):
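+    '''
+    Feature propagation (FP) module: interpolates features from the coarser layer
+    back onto the skip layer's points via k-nearest neighbours.
+    '''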
+    def __init__(self, knn_num, mlp):
+        super(PointNet2FPModule, self).__init__()
+        self.knn_num = knn_num
+        self.point_conv = PointConvFP(mlp)
+
+    def forward(self, in_layer_data, skip_layer_data):
+        in_x, in_pos, in_batch = in_layer_data
+        skip_x, skip_pos, skip_batch = skip_layer_data
+
+        row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)        
+        edge_index = torch.stack([col, row], dim=0)
+
+        x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
+        pos1, batch1 = skip_pos, skip_batch 
+
+        return x1, pos1, batch1
+
+
+def make_mlp(in_channels, mlp_channels, batch_norm=True):
+    assert len(mlp_channels) >= 1
+    layers = []
+
+    for c in mlp_channels:
+        layers += [Lin(in_channels, c)]
+        if batch_norm: layers += [BatchNorm1d(c)]
+        layers += [ReLU()]
+
+        in_channels = c
+
+    return Seq(*layers)
+
+
+class PointNet2PartSegmentNet(torch.nn.Module):
+    '''
+    ref:
+        - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
+        - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
+    '''
+    def __init__(self, num_classes):
+        super(PointNet2PartSegmentNet, self).__init__()
+        self.num_classes = num_classes
+
+        # SA1
+        sa1_sample_ratio = 0.5
+        sa1_radius = 0.2
+        sa1_max_num_neighbours = 64
+        sa1_mlp = make_mlp(3, [64, 64, 128])
+        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
+
+        # SA2
+        sa2_sample_ratio = 0.25
+        sa2_radius = 0.4
+        sa2_max_num_neighbours = 64
+        sa2_mlp = make_mlp(128+3, [128, 128, 256])
+        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
+
+        # SA3
+        sa3_mlp = make_mlp(256+3, [256, 512, 1024])
+        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
+
+        ##
+        knn_num = 3
+
+        # FP3, reverse of sa3
+        fp3_knn_num = 1  # After global sa module, there is only one point in point cloud
+        fp3_mlp = make_mlp(1024+256+3, [256, 256])
+        self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
+
+        # FP2, reverse of sa2
+        fp2_knn_num = knn_num
+        fp2_mlp = make_mlp(256+128+3, [256, 128])
+        self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
+
+        # FP1, reverse of sa1
+        fp1_knn_num = knn_num
+        fp1_mlp = make_mlp(128+3, [128, 128, 128])
+        self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
+
+        self.fc1 = Lin(128, 128)
+        self.dropout1 = Dropout(p=0.5)
+        self.fc2 = Lin(128, self.num_classes)
+
+    def forward(self, data):
+        '''
+        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
+            - torch.Tensor: (batch_size, 3, num_points), as common batch input
+
+            - torch_geometric.data.Data, as torch_geometric batch input:
+                data.x: (batch_size * ~num_points, C), batch nodes/points features;
+                    ~num_points means each sample can have a different number of points/nodes
+
+                data.pos: (batch_size * ~num_points, 3)
+
+                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
+                    identifiers for all nodes of all graphs/pointclouds in the batch. See the
+                    pytorch_geometric documentation for more information
+        '''
+        dense_input = isinstance(data, torch.Tensor)
+
+        if dense_input:
+            # Convert to torch_geometric.data.Data type
+            data = data.transpose(1, 2).contiguous()
+            batch_size, N, _ = data.shape  # (batch_size, num_points, 3)
+            pos = data.view(batch_size*N, -1)
+            batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
+            for i in range(batch_size): batch[i] = i
+            batch = batch.view(-1)
+
+            data = Data()
+            data.pos, data.batch = pos, batch
+
+        if not hasattr(data, 'x'): data.x = None
+        data_in = data.x, data.pos, data.batch
+
+        sa1_out = self.sa1_module(data_in)
+        sa2_out = self.sa2_module(sa1_out)
+        sa3_out = self.sa3_module(sa2_out)
+
+        fp3_out = self.fp3_module(sa3_out, sa2_out)
+        fp2_out = self.fp2_module(fp3_out, sa1_out)
+        fp1_out = self.fp1_module(fp2_out, data_in)
+
+        fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out
+        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
+        x = F.log_softmax(x, dim=-1)
+
+        if dense_input: return x.view(batch_size, N, self.num_classes)
+        else: return x, fp1_out_batch
+
+
+if __name__ == '__main__':
+    num_classes = 10
+    net = PointNet2PartSegmentNet(num_classes)
+
+    #
+    print('Test dense input ..')
+    data1 = torch.rand((2, 3, 1024))  # (batch_size, 3, num_points)
+    print('data1: ', data1.shape)
+
+    out1 = net(data1)
+    print('out1: ', out1.shape)
+
+    #
+    print('Test torch_geometric.data.Data input ..')
+    def make_data_batch():
+        # batch_size = 2
+        pos_num1 = 1000
+        pos_num2 = 1024
+
+        data_batch = Data()
+
+        # data_batch.x = None
+        data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
+        data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
+
+        return data_batch
+
+    data2 = make_data_batch()
+    # print('data.x: ', data.x)
+    print('data2.pos: ', data2.pos.shape)
+    print('data2.batch: ', data2.batch.shape)
+
+    out2_x, out2_batch = net(data2)
+    print('out2_x: ', out2_x.shape)
+    print('out2_batch: ', out2_batch.shape)
diff --git a/vis/show_seg_res.py b/vis/show_seg_res.py
new file mode 100644
index 0000000..1e20323
--- /dev/null
+++ b/vis/show_seg_res.py
@@ -0,0 +1,133 @@
+# Warning: importing open3d may lead to a crash, so try to import open3d (via view) first here
+from view import view_points_labels
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')  # add project root directory
+
+from dataset.shapenet import ShapeNetPartSegDataset
+from model.pointnet2_part_seg import PointNet2PartSegmentNet
+import torch_geometric.transforms as GT
+import torch
+import numpy as np
+import argparse
+
+
+##
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', type=str, default='data', help='dataset path')
+parser.add_argument('--npoints', type=int, default=50, help='resample points number')
+parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_24.pth', help='model path')
+parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
+
+opt = parser.parse_args()
+print(opt)
+
+if __name__ == '__main__':
+
+    ## Load dataset
+    print('Construct dataset ..')
+    test_transform = GT.Compose([GT.NormalizeScale(),])
+
+    test_dataset = ShapeNetPartSegDataset(
+        root_dir=opt.dataset,
+        train=False,
+        transform=test_transform,
+        npoints=opt.npoints
+    )
+    num_classes = test_dataset.num_classes()
+
+    print('test dataset size: ', len(test_dataset))
+    print('num_classes: ', num_classes)
+
+
+    # Load model
+    print('Construct model ..')
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float
+
+    # net = PointNetPartSegmentNet(num_classes)
+    net = PointNet2PartSegmentNet(num_classes)
+
+    net.load_state_dict(torch.load(opt.model))
+    net = net.to(device, dtype)
+    net.eval()
+
+
+    ##
+    def eval_sample(net, sample):
+        '''
+        sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
+        return: (pred_label, gt_label) with labels shape (n,)
+        '''
+        net.eval()
+        with torch.no_grad():
+            # points: (n, 3)
+            points, gt_label = sample['points'], sample['labels']
+            n = points.shape[0]
+
+            points = points.view(1, n, 3)  # make a batch
+            points = points.transpose(1, 2).contiguous()
+            points = points.to(device, dtype)
+
+            pred = net(points)  # (batch_size, n, num_classes)
+            pred_label = pred.max(2)[1]
+            pred_label = pred_label.view(-1).cpu()  # (n,)
+
+            assert pred_label.shape == gt_label.shape
+            return (pred_label, gt_label)
+
+
+    def compute_mIoU(pred_label, gt_label):
+        minl, maxl = np.min(gt_label), np.max(gt_label)
+        ious = []
+        for l in range(minl, maxl+1):
+            I = np.sum(np.logical_and(pred_label == l, gt_label == l))
+            U = np.sum(np.logical_or(pred_label == l, gt_label == l))
+            if U == 0: iou = 1
+            else: iou = float(I) / U
+            ious.append(iou)
+        return np.mean(ious)
+
+
+    def label_diff(pred_label, gt_label):
+        '''
+        Assign 1 if different label, or 0 if same label
+        '''
+        diff = pred_label - gt_label
+        diff_mask = (diff != 0)
+
+        diff_label = np.zeros((pred_label.shape[0]), dtype=np.int32)
+        diff_label[diff_mask] = 1
+
+        return diff_label
+
+
+    # Get one sample and eval
+    sample = test_dataset[opt.sample_idx]
+
+    print('Eval test sample ..')
+    pred_label, gt_label = eval_sample(net, sample)
+    print('Eval done ..')
+
+
+    # Get sample result
+    print('Compute mIoU ..')
+    points = sample['points'].numpy()
+    pred_labels = pred_label.numpy()
+    gt_labels = gt_label.numpy()
+    diff_labels = label_diff(pred_labels, gt_labels)
+
+    print('mIoU: ', compute_mIoU(pred_labels, gt_labels))
+
+
+    # View result
+
+    # print('View gt labels ..')
+    # view_points_labels(points, gt_labels)
+
+    # print('View diff labels ..')
+    # view_points_labels(points, diff_labels)
+
+    print('View pred labels ..')
+    view_points_labels(points, pred_labels)
diff --git a/vis/view.py b/vis/view.py
new file mode 100644
index 0000000..948fbf1
--- /dev/null
+++ b/vis/view.py
@@ -0,0 +1,63 @@
+import open3d as o3d
+import numpy as np
+
+
+def mini_color_table(index, norm=True):
+    colors = [
+        [0.5000, 0.5400, 0.5300], [0.8900, 0.1500, 0.2100], [0.6400, 0.5800, 0.5000],
+        [1.0000, 0.3800, 0.0100], [1.0000, 0.6600, 0.1400], [0.4980, 1.0000, 0.0000],
+        [0.4980, 1.0000, 0.8314], [0.9412, 0.9725, 1.0000], [0.5412, 0.1686, 0.8863],
+        [0.5765, 0.4392, 0.8588], [0.3600, 0.1400, 0.4300], [0.5600, 0.3700, 0.6000],
+    ]
+
+    assert 0 <= index < len(colors)
+    color = list(colors[index])  # copy, so repeated calls do not mutate the shared table
+
+    if not norm:
+        color[0] *= 255
+        color[1] *= 255
+        color[2] *= 255
+
+    return color
+
+
+def view_points(points, colors=None):
+    '''
+    points: np.ndarray with shape (n, 3)
+    colors: [r, g, b] or np.array with shape (n, 3)
+    '''
+    cloud = o3d.PointCloud()
+    cloud.points = o3d.Vector3dVector(points) 
+
+    if colors is not None:
+        if isinstance(colors, np.ndarray):
+            cloud.colors = o3d.Vector3dVector(colors)
+        else: cloud.paint_uniform_color(colors)
+
+    o3d.draw_geometries([cloud])
+
+
+def label2color(labels):
+    '''
+    labels: np.ndarray with shape (n, )
+    colors(return): np.ndarray with shape (n, 3)
+    '''
+    num = labels.shape[0]
+    colors = np.zeros((num, 3))
+
+    minl, maxl = np.min(labels), np.max(labels)
+    for l in range(minl, maxl + 1):
+        colors[labels==l, :] = mini_color_table(l) 
+
+    return colors
+
+
+def view_points_labels(points, labels):
+    '''
+    Assign points with colors by labels and view colored points.
+    points: np.ndarray with shape (n, 3)
+    labels: np.ndarray with shape (n,), dtype=np.int32
+    '''
+    assert points.shape[0] == labels.shape[0]
+    colors = label2color(labels)
+    view_points(points, colors)