From 0159363642ebd4373caae9de9b706c0e1557b0f6 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Tue, 30 Jul 2019 07:44:45 +0200 Subject: [PATCH] initial commit --- .gitignore | 129 ++++++++++++++++ .idea/.gitignore | 2 + .idea/misc.xml | 7 + .idea/modules.xml | 8 + .idea/pointnet2-pytorch.iml | 14 ++ .idea/vcs.xml | 6 + README.md | 54 +++++++ dataset/shapenet.py | 158 ++++++++++++++++++++ main.py | 200 +++++++++++++++++++++++++ model/pointnet2_part_seg.py | 286 ++++++++++++++++++++++++++++++++++++ vis/show_seg_res.py | 133 +++++++++++++++++ vis/view.py | 63 ++++++++ 12 files changed, 1060 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/.gitignore create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/pointnet2-pytorch.iml create mode 100644 .idea/vcs.xml create mode 100644 README.md create mode 100644 dataset/shapenet.py create mode 100644 main.py create mode 100644 model/pointnet2_part_seg.py create mode 100644 vis/show_seg_res.py create mode 100644 vis/view.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..87b49df --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+/data/
+/checkpoint/
+/shapenet/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..5c98b42
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..a663f10
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..57bd2fd
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/pointnet2-pytorch.iml b/.idea/pointnet2-pytorch.iml
new file mode 100644
index 0000000..7401a39
--- /dev/null
+++ b/.idea/pointnet2-pytorch.iml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9dd12af
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+# PointNet++ part segmentation
+This repo is an implementation of the [PointNet++](https://arxiv.org/abs/1706.02413) part segmentation model, based on [PyTorch](https://pytorch.org) and [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric). It achieves performance comparable to or better than [PointCNN](https://arxiv.org/abs/1801.07791) on the ShapeNet dataset.
+
+**The model has been merged into [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) as a point cloud segmentation [example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py); give it a try.**
+
+# Performance
+Segmentation results on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html):
+
+| Method | mcIoU|Airplane|Bag|Cap|Car|Chair|Earphone|Guitar|Knife|Lamp|Laptop|Motorbike|Mug|Pistol|Rocket|Skateboard|Table|
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| PointNet++ | 81.9| 82.4| 79.0| 87.7| 77.3 |90.8| 71.8| 91.0| 85.9| 83.7| 95.3| 71.6| 94.1| 81.3| 58.7| 76.4| 82.6|
+| PointCNN | 84.6| 84.11| **86.47**| 86.04| **80.83**| 90.62| **79.70**| 92.32| 88.44| 85.31| 96.11| **77.20**| 95.28| 84.21| 64.23| **80.00**| 82.99|
+| PointNet++(this repo) | **84.68**| **85.42**| 85.92| **88.39**| 79.73| **91.86**| 75.37| **92.95**| **88.56**| **85.72**| **97.00**| 72.94| **96.88**| **84.52**| **64.38**| 79.39| **85.91**|
+
+mcIoU: mean per-class part IoU (pIoU)
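+
+For reference, `main.py` scores each shape by averaging the IoU of every part class; averaging those scores within a category gives that category's pIoU, and averaging across categories gives mcIoU. A minimal sketch of the per-shape step (the helper name is illustrative, not part of this repo):
+```
+import numpy as np
+
+def shape_part_iou(pred, target, num_parts):
+    # IoU per part class; a part absent from both prediction and
+    # ground truth counts as a perfect match (IoU = 1), as in main.py.
+    ious = []
+    for part in range(num_parts):
+        inter = np.sum((pred == part) & (target == part))
+        union = np.sum((pred == part) | (target == part))
+        ious.append(1.0 if union == 0 else inter / union)
+    return float(np.mean(ious))
+```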
+
+# Requirements
+- [PyTorch](https://pytorch.org)
+- [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric)
+- [Open3D](https://github.com/intel-isl/Open3D) (optional, for visualizing segmentation results)
+
+## Quickly install pytorch_geometric and Open3D with Anaconda
+```
+$ pip install --verbose --no-cache-dir torch-scatter
+$ pip install --verbose --no-cache-dir torch-sparse
+$ pip install --verbose --no-cache-dir torch-cluster
+$ pip install --verbose --no-cache-dir torch-spline-conv  # optional
+$ pip install torch-geometric
+```
+
+```
+# optional
+conda install -c open3d-admin open3d
+```
+
+# Usage
+Training:
+```
+python main.py
+```
+
+Show a segmentation result:
+```
+python vis/show_seg_res.py
+```
+
+# Sample segmentation result
+![segmentation_result](figs/segmentation_result.png)
+
+
+# Links
+- [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) by fxia22. The training code in this repo borrows heavily from fxia22's repo.
+- Official [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2) TensorFlow implementations
+- [PointNet++ classification example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet%2B%2B.py) of the pytorch_geometric library
diff --git a/dataset/shapenet.py b/dataset/shapenet.py
new file mode 100644
index 0000000..c8a694e
--- /dev/null
+++ b/dataset/shapenet.py
@@ -0,0 +1,158 @@
+import os
+import glob
+from collections import defaultdict
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch.utils.data import Dataset
+from torch_geometric.data import Data, InMemoryDataset
+
+
+class CustomShapeNet(InMemoryDataset):
+
+    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
+
+    def __init__(self, root, train=True, transform=None, pre_filter=None, pre_transform=None, **kwargs):
+        super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
+        self.train = train  # remembered for _load_dataset()
+        path = self.processed_paths[0] if train else self.processed_paths[1]
+        self.data, self.slices = torch.load(path)
+        print("Initialized")
+
+    @property
+    def raw_file_names(self):
+        # Maybe add more data, e.g. validation sets
+        return ['train', 'test']
+
+    @property
+    def processed_file_names(self):
+        return [f'{x}.pt' for x in self.raw_file_names]
+    def download(self):
+        # The raw point clouds are provided manually, so "downloading" only verifies they exist.
+        dir_count = len([name for name in os.listdir(self.raw_dir) if
+                         os.path.isdir(os.path.join(self.raw_dir, name))])
+        print(f'{dir_count} folders have been found.')
+        if dir_count:
+            return dir_count
+        raise IOError("No raw pointclouds have been found.")
+
+    @property
+    def num_classes(self):
+        return len(self.categories)
+
+    def _load_dataset(self):
+        data, slices = None, None
+        while True:
+            try:
+                filepath = os.path.join(self.root, self.processed_dir, f'{"train" if self.train else "test"}.pt')
+                data, slices = torch.load(filepath)
+                print('Dataset Loaded')
+                break
+            except FileNotFoundError:
+                self.process()
+                continue
+        return data, slices
+
+    def process(self, delimiter=' '):
+        datasets = defaultdict(list)
+        for idx, setting in enumerate(self.raw_file_names):
+            for pointcloud in tqdm(os.scandir(os.path.join(self.raw_dir, setting))):
+                if not os.path.isdir(pointcloud):
+                    continue
+                for element in glob.glob(os.path.join(pointcloud.path, '*.dat')):
+                    if os.path.split(element)[-1] not in ['pc.dat']:
+                        # Assign training data to the data container.
+                        # Following the original logic:
+                        # y is the label;
+                        # pos is the six-dimensional vector describing the pose (not the points!):
+                        # x, y, z, x_rot, y_rot, z_rot
+                        y_raw = os.path.splitext(element)[0].split('_')[-2]
+                        with open(element, 'r') as f:
+                            headers = next(f)
+                            # Skip this file if its header says it holds no usable nodes.
+                            if not int(headers.rstrip().split(delimiter)[0]):
+                                continue
+
+                            # Iterate over all remaining rows
+                            src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
+                                    for x in line.rstrip().split(delimiter)] for line in f if line != '']
+                        points = torch.tensor(src).squeeze()
+                        if not len(points.shape) > 1:
+                            continue
+                        # Get the y label: every point of this cloud carries its category's index
+                        y_all = [self.categories[y_raw]] * points.shape[0]
+                        y = torch.as_tensor(y_all, dtype=torch.int)
+                        data = Data(y=y, pos=points[:, :3])
+                        # ToDo: Any filter to apply? Then do it here.
+                        if self.pre_filter is not None and not self.pre_filter(data):
+                            continue
+                        # ToDo: Any transformation to apply? Then do it here.
+                        if self.pre_transform is not None:
+                            data = self.pre_transform(data)
+                        datasets[setting].append(data)
+
+            os.makedirs(self.processed_dir, exist_ok=True)
+            torch.save(self.collate(datasets[setting]), self.processed_paths[idx])
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({len(self)})'
+
+
+class ShapeNetPartSegDataset(Dataset):
+    """
+    Resample raw point cloud to a fixed number of points.
+    Map raw labels from range [1, N] to [0, N-1].
+    """
+    def __init__(self, root_dir, train=True, transform=None, npoints=1024):
+        super(ShapeNetPartSegDataset, self).__init__()
+        self.npoints = npoints
+        self.dataset = CustomShapeNet(root=root_dir, train=train, transform=transform)
+
+    def __getitem__(self, index):
+        data = self.dataset[index]
+        points, labels = data.pos, data.y
+
+        # Resample to a fixed number of points
+        try:
+            choice = np.random.choice(points.shape[0], self.npoints, replace=True)
+        except ValueError:
+            choice = []
+
+        points, labels = points[choice, :], labels[choice]
+
+        labels -= 1 if self.num_classes() in labels else 0  # Map labels from [1, C] to [0, C-1]
+
+        sample = {
+            'points': points,  # torch.Tensor (n, 3)
+            'labels': labels   # torch.Tensor (n,)
+        }
+
+        return sample
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def num_classes(self):
+        return self.dataset.num_classes
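+
+
+if __name__ == '__main__':
+    # Minimal usage sketch: load the resampled dataset and draw one batch.
+    # This assumes raw *.dat clouds under <root_dir>/raw/train and <root_dir>/raw/test
+    # (as raw_file_names above implies); the 'data' path is a placeholder.
+    from torch.utils.data import DataLoader
+
+    dataset = ShapeNetPartSegDataset(root_dir='data', train=True, npoints=1024)
+    loader = DataLoader(dataset, batch_size=8, shuffle=True)
+
+    sample = dataset[0]
+    print('points: ', sample['points'].shape)  # torch.Size([1024, 3])
+    print('labels: ', sample['labels'].shape)  # torch.Size([1024])
+
+    batch = next(iter(loader))
+    print('batch points: ', batch['points'].shape)  # torch.Size([8, 1024, 3])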
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c287cc7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,200 @@
+"""
+Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
+"""
+import os
+import random
+import numpy as np
+import argparse
+import torch
+from torch.utils.data import DataLoader
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+from dataset.shapenet import ShapeNetPartSegDataset
+from model.pointnet2_part_seg import PointNet2PartSegmentNet
+import torch_geometric.transforms as GT
+
+import time
+
+# Argument parser
+parser = argparse.ArgumentParser()
+default_data_dir = os.path.join(os.getcwd(), 'data')
+parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
+parser.add_argument('--npoints', type=int, default=50, help='number of points to resample each cloud to')
+parser.add_argument('--model', type=str, default='checkpoint/seg_model_custom_24.pth', help="model path to resume from (pass '' to train from scratch)")
+parser.add_argument('--nepoch', type=int, default=10, help='number of epochs to train for')
+parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
+parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
+parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch every this many training batches')
+parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
+
+opt = parser.parse_args()
+print(opt)
+
+# Random seed
+opt.manual_seed = 123
+print('Random seed: ', opt.manual_seed)
+random.seed(opt.manual_seed)
+np.random.seed(opt.manual_seed)
+torch.manual_seed(opt.manual_seed)
+torch.cuda.manual_seed(opt.manual_seed)
+
+if __name__ == '__main__':
+
+    # Dataset and transforms
+    print('Construct dataset ..')
+    rot_max_angle = 15
+    trans_max_distance = 0.01
+
+    RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
+                               GT.RandomRotate(rot_max_angle, 1),
+                               GT.RandomRotate(rot_max_angle, 2)])
+    TransTransform = GT.RandomTranslate(trans_max_distance)
+
+    train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
+    test_transform = GT.Compose([GT.NormalizeScale(), ])
+
+    dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
+    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+
+    test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform,
+                                          npoints=opt.npoints)
+    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
+
+    num_classes = dataset.num_classes()
+
+    print('dataset size: ', len(dataset))
+    print('test_dataset size: ', len(test_dataset))
+    print('num_classes: ', num_classes)
+
+    # Create the checkpoint folder; unlike a bare mkdir/pass this does not swallow permission errors.
+    os.makedirs(opt.outf, exist_ok=True)
+
+    ## Model, criterion and optimizer
+    print('Construct model ..')
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float
+    print('cudnn.enabled: ', torch.backends.cudnn.enabled)
+
+    net = PointNet2PartSegmentNet(num_classes)
+
+    if opt.model != '':
+        net.load_state_dict(torch.load(opt.model))
+    net = net.to(device, dtype)
+
+    criterion = nn.NLLLoss()
+    optimizer = optim.Adam(net.parameters())
+
+    ## Train
+    print('Training ..')
+    blue = lambda x: '\033[94m' + x + '\033[0m'
+    num_batch = len(dataset) // opt.batch_size
+    test_per_batches = opt.test_per_batches
+
+    print('number of epochs: ', opt.nepoch)
+    print('number of batches per epoch: ', num_batch)
+    print('run test per batches: ', test_per_batches)
+
+    for epoch in range(opt.nepoch):
+        print('Epoch {} of {}'.format(epoch+1, opt.nepoch))
+
+        net.train()
+
+        for batch_idx, sample in enumerate(dataloader):
+            # points: (batch_size, n, 3)
+            # labels: (batch_size, n)
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
+            points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+            optimizer.zero_grad()
+
+            pred = net(points)  # (batch_size, n, num_classes)
+            pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
+            target = labels.view(-1, 1)[:, 0]
+
+            loss = F.nll_loss(pred, target)
+            loss.backward()
+
+            optimizer.step()
+
+            ##
+            pred_label = pred.detach().max(1)[1]
+            correct = pred_label.eq(target.detach()).cpu().sum()
+            total = pred_label.shape[0]
+
+            print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
+
+            ##
+            if batch_idx % test_per_batches == 0:
+                print('Run a test batch')
+                net.eval()
+
+                with torch.no_grad():
+                    # Always evaluates on the first test batch; do not overwrite the training batch_idx.
+                    _, sample = next(enumerate(test_dataloader))
+
+                    points, labels = sample['points'], sample['labels']
+                    points = points.transpose(1, 2).contiguous()
+                    points, labels = points.to(device, dtype), labels.to(device, torch.long)
+
+                    pred = net(points)
+                    pred = pred.view(-1, num_classes)
+                    target = labels.view(-1, 1)[:, 0]
+
+                    # Guard against labels shifted below zero by the dataset's remapping
+                    target += 1 if -1 in target else 0
+                    loss = F.nll_loss(pred, target)
+
+                    pred_label = pred.detach().max(1)[1]
+                    correct = pred_label.eq(target.detach()).cpu().sum()
+                    total = pred_label.shape[0]
+                    print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
+
+                # Back to training mode
+                net.train()
+
+        torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
+
+
+    ## Benchmark mIoU
+    net.eval()
+    shape_ious = []
+
+    with torch.no_grad():
+        for batch_idx, sample in enumerate(test_dataloader):
+            points, labels = sample['points'], sample['labels']
+            points = points.transpose(1, 2).contiguous()
+            points = points.to(device, dtype)
+
+            # start_t = time.time()
+            pred = net(points)  # (batch_size, n, num_classes)
+            # print('batch inference forward time used: {} ms'.format(time.time() - start_t))
+
+            pred_label = pred.max(2)[1]
+            pred_label = pred_label.cpu().numpy()
+            target_label = labels.numpy()
+
+            batch_size = target_label.shape[0]
+            for shape_idx in range(batch_size):
+                parts = range(num_classes)  # np.unique(target_label[shape_idx])
+                part_ious = []
+                for part in parts:
+                    I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
+                    U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
+                    # A part absent from both prediction and ground truth counts as perfectly matched.
+                    iou = 1 if U == 0 else float(I) / U
+                    part_ious.append(iou)
+                shape_ious.append(np.mean(part_ious))
+
+    print(f'mIoU on the custom dataset: {np.mean(shape_ious)}')
+
+    print('Done.')
diff --git a/model/pointnet2_part_seg.py b/model/pointnet2_part_seg.py
new file mode 100644
index 0000000..64a7cc9
--- /dev/null
+++ b/model/pointnet2_part_seg.py
@@ -0,0 +1,286 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
+from torch_geometric.nn import PointConv, fps, radius, knn
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.nn.inits import reset
+from torch_geometric.utils.num_nodes import maybe_num_nodes
+from torch_geometric.data.data import Data
+from torch_scatter import scatter_add, scatter_max
+
+
+class PointNet2SAModule(torch.nn.Module):
+    def __init__(self, sample_ratio, radius, max_num_neighbors, mlp):
+        super(PointNet2SAModule, self).__init__()
+        self.sample_ratio = sample_ratio
+        self.radius = radius
+        self.max_num_neighbors = max_num_neighbors
+        self.point_conv = PointConv(mlp)
+
+    def forward(self, data):
+        x, pos, batch = data
+
+        # Sample
+        idx = fps(pos, batch, ratio=self.sample_ratio)
+
+        # Group (build the neighbourhood graph)
+        row, col = radius(pos, pos[idx], self.radius, batch, batch[idx], max_num_neighbors=self.max_num_neighbors)
+        edge_index = torch.stack([col, row], dim=0)
+
+        # Apply PointNet
+        x1 = self.point_conv(x, (pos, pos[idx]), edge_index)
+        pos1, batch1 = pos[idx], batch[idx]
+
+        return x1, pos1, batch1
+
+
+class PointNet2GlobalSAModule(torch.nn.Module):
+    '''
+    One group with all input points; can be viewed as a simple PointNet module.
+    It also returns exactly one output point (set to the origin).
+    '''
+    def __init__(self, mlp):
+        super(PointNet2GlobalSAModule, self).__init__()
+        self.mlp = mlp
+
+    def forward(self, data):
+        x, pos, batch = data
+        if x is not None: x = torch.cat([x, pos], dim=1)
+        x1 = self.mlp(x)
+
+        x1 = scatter_max(x1, batch, dim=0)[0]  # (batch_size, C1)
+
+        batch_size = x1.shape[0]
+        pos1 = x1.new_zeros((batch_size, 3))  # set the output point to the origin
+        batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
+
+        return x1, pos1, batch1
+
+
+class PointConvFP(MessagePassing):
+    '''
+    Core layer of the feature propagation module.
+    '''
+    def __init__(self, mlp=None):
+        super(PointConvFP, self).__init__('add', 'source_to_target')
+        self.mlp = mlp
+        self.aggr = 'add'
+        self.flow = 'source_to_target'
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        reset(self.mlp)
+
+    def forward(self, x, pos, edge_index):
+        r"""
+        Args:
+            x (tuple): (tensor, tensor) or (tensor, NoneType)
+            pos (tuple): The node position matrix. Either given as
+                tensor for use in general message passing or as tuple for use
+                in message passing in bipartite graphs.
+            edge_index (LongTensor): The edge indices.
+        """
+        # Do not pass (tensor, None) directly into propagate(), since it will check each item's size() inside.
+        x_tmp = x[0] if x[1] is None else x
+        aggr_out = self.propagate(edge_index, x=x_tmp, pos=pos)
+
+        #
+        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
+        x_target, pos_target = x[i], pos[i]
+
+        add = [pos_target, ] if x_target is None else [x_target, pos_target]
+        aggr_out = torch.cat([aggr_out, *add], dim=1)
+
+        if self.mlp is not None: aggr_out = self.mlp(aggr_out)
+
+        return aggr_out
+
+    def message(self, x_j, pos_j, pos_i, edge_index):
+        '''
+        x_j: (E, in_channels)
+        pos_j: (E, 3)
+        pos_i: (E, 3)
+        '''
+        # Inverse-distance weighting: closer neighbours contribute more to the interpolated feature.
+        dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
+        dist = torch.max(dist, torch.Tensor([1e-10]).to(dist.device, dist.dtype))
+        weight = 1.0 / dist  # (E,)
+
+        row, col = edge_index
+        index = col
+        num_nodes = maybe_num_nodes(index, None)
+        wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16  # (E,)
+        weight /= wsum
+
+        return weight.view(-1, 1) * x_j
+
+    def update(self, aggr_out):
+        return aggr_out
+
+
+class PointNet2FPModule(torch.nn.Module):
+    def __init__(self, knn_num, mlp):
+        super(PointNet2FPModule, self).__init__()
+        self.knn_num = knn_num
+        self.point_conv = PointConvFP(mlp)
+
+    def forward(self, in_layer_data, skip_layer_data):
+        in_x, in_pos, in_batch = in_layer_data
+        skip_x, skip_pos, skip_batch = skip_layer_data
+
+        row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
+        edge_index = torch.stack([col, row], dim=0)
+
+        x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
+        pos1, batch1 = skip_pos, skip_batch
+
+        return x1, pos1, batch1
+
+
+def make_mlp(in_channels, mlp_channels, batch_norm=True):
+    assert len(mlp_channels) >= 1
+    layers = []
+
+    for c in mlp_channels:
+        layers += [Lin(in_channels, c)]
+        if batch_norm: layers += [BatchNorm1d(c)]
+        layers += [ReLU()]
+
+        in_channels = c
+
+    return Seq(*layers)
+
+
+class PointNet2PartSegmentNet(torch.nn.Module):
+    '''
+    ref:
+    - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
+    - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
+    '''
+    def __init__(self, num_classes):
+        super(PointNet2PartSegmentNet, self).__init__()
+        self.num_classes = num_classes
+
+        # SA1
+        sa1_sample_ratio = 0.5
+        sa1_radius = 0.2
+        sa1_max_num_neighbours = 64
+        sa1_mlp = make_mlp(3, [64, 64, 128])
+        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
+
+        # SA2
+        sa2_sample_ratio = 0.25
+        sa2_radius = 0.4
+        sa2_max_num_neighbours = 64
+        sa2_mlp = make_mlp(128+3, [128, 128, 256])
+        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
+
+        # SA3
+        sa3_mlp = make_mlp(256+3, [256, 512, 1024])
+        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
+
+        ##
+        knn_num = 3
+
+        # FP3, reverse of sa3
+        fp3_knn_num = 1  # After the global SA module there is only one point left in the point cloud
+        fp3_mlp = make_mlp(1024+256+3, [256, 256])
+        self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
+
+        # FP2, reverse of sa2
+        fp2_knn_num = knn_num
+        fp2_mlp = make_mlp(256+128+3, [256, 128])
+        self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
+
+        # FP1, reverse of sa1
+        fp1_knn_num = knn_num
+        fp1_mlp = make_mlp(128+3, [128, 128, 128])
+        self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
+
+        self.fc1 = Lin(128, 128)
+        self.dropout1 = Dropout(p=0.5)
+        self.fc2 = Lin(128, self.num_classes)
+
+    def forward(self, data):
+        '''
+        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
+            - torch.Tensor: (batch_size, 3, num_points), as common batch input
+
+            - torch_geometric.data.Data, as torch_geometric batch input:
+                data.x: (batch_size * ~num_points, C), batch nodes/points feature;
+                    ~num_points means each sample can have a different number of points/nodes
+
+                data.pos: (batch_size * ~num_points, 3)
+
+                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
+                    identifiers for all nodes of all graphs/pointclouds in the batch. See the
+                    pytorch_geometric documentation for more information.
+        '''
+        dense_input = isinstance(data, torch.Tensor)
+
+        if dense_input:
+            # Convert to torch_geometric.data.Data type
+            data = data.transpose(1, 2).contiguous()
+            batch_size, N, _ = data.shape  # (batch_size, num_points, 3)
+            pos = data.view(batch_size*N, -1)
+            batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
+            for i in range(batch_size): batch[i] = i
+            batch = batch.view(-1)
+
+            data = Data()
+            data.pos, data.batch = pos, batch
+
+        if not hasattr(data, 'x'): data.x = None
+        data_in = data.x, data.pos, data.batch
+
+        sa1_out = self.sa1_module(data_in)
+        sa2_out = self.sa2_module(sa1_out)
+        sa3_out = self.sa3_module(sa2_out)
+
+        fp3_out = self.fp3_module(sa3_out, sa2_out)
+        fp2_out = self.fp2_module(fp3_out, sa1_out)
+        fp1_out = self.fp1_module(fp2_out, data_in)
+
+        fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out
+        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
+        x = F.log_softmax(x, dim=-1)
+
+        if dense_input: return x.view(batch_size, N, self.num_classes)
+        else: return x, fp1_out_batch
+
+
+if __name__ == '__main__':
+    num_classes = 10
+    net = PointNet2PartSegmentNet(num_classes)
+
+    #
+    print('Test dense input ..')
+    data1 = torch.rand((2, 3, 1024))  # (batch_size, 3, num_points)
+    print('data1: ', data1.shape)
+
+    out1 = net(data1)
+    print('out1: ', out1.shape)
+
+    #
+    print('Test torch_geometric.data.Data input ..')
+    def make_data_batch():
+        # batch_size = 2
+        pos_num1 = 1000
+        pos_num2 = 1024
+
+        data_batch = Data()
+
+        # data_batch.x = None
+        data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
+        data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
+
+        return data_batch
+
+    data2 = make_data_batch()
+    # print('data.x: ', data.x)
+    print('data2.pos: ', data2.pos.shape)
+    print('data2.batch: ', data2.batch.shape)
+
+    out2_x, out2_batch = net(data2)
+    print('out2_x: ', out2_x.shape)
+    print('out2_batch: ', out2_batch.shape)
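+
+    #
+    print('Test loading a trained checkpoint (sketch) ..')
+    # A trained model is restored the same way; the path below is hypothetical
+    # (main.py saves checkpoint/seg_model_custom_<epoch>.pth) and the class count
+    # must match the one used during training, so this only runs if the file exists.
+    import os
+    checkpoint = 'checkpoint/seg_model_custom_9.pth'
+    if os.path.exists(checkpoint):
+        net_trained = PointNet2PartSegmentNet(num_classes)
+        net_trained.load_state_dict(torch.load(checkpoint, map_location='cpu'))
+        net_trained.eval()
+        with torch.no_grad():
+            pred = net_trained(torch.rand(2, 3, 1024))  # (2, 1024, num_classes) log-probabilities
+        print('pred labels: ', pred.max(2)[1].shape)    # torch.Size([2, 1024])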
diff --git a/vis/show_seg_res.py b/vis/show_seg_res.py
new file mode 100644
index 0000000..1e20323
--- /dev/null
+++ b/vis/show_seg_res.py
@@ -0,0 +1,133 @@
+# Warning: importing open3d may lead to a crash, so import it (via view) first here
+from view import view_points_labels
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')  # add the project root directory
+
+from dataset.shapenet import ShapeNetPartSegDataset
+from model.pointnet2_part_seg import PointNet2PartSegmentNet
+import torch_geometric.transforms as GT
+import torch
+import numpy as np
+import argparse
+
+
+##
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', type=str, default='data', help='dataset path')
+parser.add_argument('--npoints', type=int, default=50, help='number of points to resample each cloud to')
+parser.add_argument('--model', type=str, default='./checkpoint/seg_model_Airplane_24.pth', help='model path')
+parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view the result')
+
+opt = parser.parse_args()
+print(opt)
+
+if __name__ == '__main__':
+
+    ## Load dataset
+    print('Construct dataset ..')
+    test_transform = GT.Compose([GT.NormalizeScale(), ])
+
+    test_dataset = ShapeNetPartSegDataset(
+        root_dir=opt.dataset,
+        train=False,
+        transform=test_transform,
+        npoints=opt.npoints
+    )
+    num_classes = test_dataset.num_classes()
+
+    print('test dataset size: ', len(test_dataset))
+    print('num_classes: ', num_classes)
+
+
+    # Load model
+    print('Construct model ..')
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float
+
+    # net = PointNetPartSegmentNet(num_classes)
+    net = PointNet2PartSegmentNet(num_classes)
+
+    net.load_state_dict(torch.load(opt.model))
+    net = net.to(device, dtype)
+    net.eval()
+
+
+    ##
+    def eval_sample(net, sample):
+        '''
+        sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
+        return: (pred_label, gt_label) with labels shape (n,)
+        '''
+        net.eval()
+        with torch.no_grad():
+            # points: (n, 3)
+            points, gt_label = sample['points'], sample['labels']
+            n = points.shape[0]
+
+            points = points.view(1, n, 3)  # make a batch
+            points = points.transpose(1, 2).contiguous()
+            points = points.to(device, dtype)
+
+            pred = net(points)  # (batch_size, n, num_classes)
+            pred_label = pred.max(2)[1]
+            pred_label = pred_label.view(-1).cpu()  # (n,)
+
+            assert pred_label.shape == gt_label.shape
+            return (pred_label, gt_label)
+
+
+    def compute_mIoU(pred_label, gt_label):
+        # Mean IoU over all labels occurring in the ground truth
+        minl, maxl = np.min(gt_label), np.max(gt_label)
+        ious = []
+        for l in range(minl, maxl + 1):
+            I = np.sum(np.logical_and(pred_label == l, gt_label == l))
+            U = np.sum(np.logical_or(pred_label == l, gt_label == l))
+            if U == 0: iou = 1
+            else: iou = float(I) / U
+            ious.append(iou)
+        return np.mean(ious)
+
+
+    def label_diff(pred_label, gt_label):
+        '''
+        Assign 1 where the labels differ, 0 where they match
+        '''
+        diff = pred_label - gt_label
+        diff_mask = (diff != 0)
+
+        diff_label = np.zeros((pred_label.shape[0]), dtype=np.int32)
+        diff_label[diff_mask] = 1
+
+        return diff_label
+
+
+    # Get one sample and eval
+    sample = test_dataset[opt.sample_idx]
+
+    print('Eval test sample ..')
+    pred_label, gt_label = eval_sample(net, sample)
+    print('Eval done ..')
+
+
+    # Get sample result
+    print('Compute mIoU ..')
+    points = sample['points'].numpy()
+    pred_labels = pred_label.numpy()
+    gt_labels = gt_label.numpy()
+    diff_labels = label_diff(pred_labels, gt_labels)
+
+    print('mIoU: ', compute_mIoU(pred_labels, gt_labels))
+
+
+    # View result
+
+    # print('View gt labels ..')
+    # view_points_labels(points, gt_labels)
+
+    # print('View diff labels ..')
+    # view_points_labels(points, diff_labels)
+
+    print('View pred labels ..')
+    view_points_labels(points, pred_labels)
diff --git a/vis/view.py b/vis/view.py
new file mode 100644
index 0000000..948fbf1
--- /dev/null
+++ b/vis/view.py
@@ -0,0 +1,63 @@
+import open3d as o3d
+import numpy as np
+
+
+def mini_color_table(index, norm=True):
+    colors = [
+        [0.5000, 0.5400, 0.5300], [0.8900, 0.1500, 0.2100], [0.6400, 0.5800, 0.5000],
+        [1.0000, 0.3800, 0.0100], [1.0000, 0.6600, 0.1400], [0.4980, 1.0000, 0.0000],
+        [0.4980, 1.0000, 0.8314], [0.9412, 0.9725, 1.0000], [0.5412, 0.1686, 0.8863],
+        [0.5765, 0.4392, 0.8588], [0.3600, 0.1400, 0.4300], [0.5600, 0.3700, 0.6000],
+    ]
+
+    assert 0 <= index < len(colors)
+    color = colors[index]
+
+    if not norm:
+        color[0] *= 255
+        color[1] *= 255
+        color[2] *= 255
+
+    return color
+
+
+def view_points(points, colors=None):
+    '''
+    points: np.ndarray with shape (n, 3)
+    colors: [r, g, b] or np.ndarray with shape (n, 3)
+    '''
+    # Note: the flat o3d.PointCloud / o3d.draw_geometries calls target the older
+    # Open3D API that was current when this repo was written.
+    cloud = o3d.PointCloud()
+    cloud.points = o3d.Vector3dVector(points)
+
+    if colors is not None:
+        if isinstance(colors, np.ndarray):
+            cloud.colors = o3d.Vector3dVector(colors)
+        else: cloud.paint_uniform_color(colors)
+
+    o3d.draw_geometries([cloud])
+
+
+def label2color(labels):
+    '''
+    labels: np.ndarray with shape (n,)
+    colors (return): np.ndarray with shape (n, 3)
+    '''
+    num = labels.shape[0]
+    colors = np.zeros((num, 3))
+
+    minl, maxl = np.min(labels), np.max(labels)
+    for l in range(minl, maxl + 1):
+        colors[labels == l, :] = mini_color_table(l)
+
+    return colors
+
+
+def view_points_labels(points, labels):
+    '''
+    Assign points with colors by labels and view the colored points.
+    points: np.ndarray with shape (n, 3)
+    labels: np.ndarray with shape (n,), dtype=np.int32
+    '''
+    assert points.shape[0] == labels.shape[0]
+    colors = label2color(labels)
+    view_points(points, colors)
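+
+
+if __name__ == '__main__':
+    # Smoke-test sketch for the viewer: random synthetic points with four part
+    # labels (purely illustrative data; opens an Open3D window, needs a display).
+    pts = np.random.rand(1024, 3)
+    lbls = np.random.randint(0, 4, size=1024)
+    view_points_labels(pts, lbls)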