initial commit

Si11ium 2019-07-30 07:44:45 +02:00
commit 0159363642
12 changed files with 1060 additions and 0 deletions

129
.gitignore vendored Normal file

@ -0,0 +1,129 @@
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
/data/
/checkpoint/
/shapenet/

2
.idea/.gitignore generated vendored Normal file

@ -0,0 +1,2 @@
# Default ignored files
/workspace.xml

7
.idea/misc.xml generated Normal file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (torch)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" filepath="$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" />
</modules>
</component>
</project>

14
.idea/pointnet2-pytorch.iml generated Normal file

@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/data" />
<excludeFolder url="file://$MODULE_DIR$/net" />
</content>
<orderEntry type="jdk" jdkName="Python 3.7 (torch)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>

6
.idea/vcs.xml generated Normal file

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

54
README.md Normal file

@ -0,0 +1,54 @@
# Pointnet++ Part segmentation
This repo is an implementation of the [PointNet++](https://arxiv.org/abs/1706.02413) part segmentation model based on [PyTorch](https://pytorch.org) and [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric). It achieves performance comparable to or better than [PointCNN](https://arxiv.org/abs/1801.07791) on the ShapeNet dataset.
**The model has been merged into [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) as a point cloud segmentation [example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py); you can try it there.**
# Performance
Segmentation on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
| Method | mcIoU | Airplane | Bag | Cap | Car | Chair | Earphone | Guitar | Knife | Lamp | Laptop | Motorbike | Mug | Pistol | Rocket | Skateboard | Table |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| PointNet++ | 81.9| 82.4| 79.0| 87.7| 77.3 |90.8| 71.8| 91.0| 85.9| 83.7| 95.3| 71.6| 94.1| 81.3| 58.7| 76.4| 82.6|
| PointCNN | 84.6| 84.11| **86.47**| 86.04| **80.83**| 90.62| **79.70**| 92.32| 88.44| 85.31| 96.11| **77.20**| 95.28| 84.21| 64.23| **80.00**| 82.99|
| PointNet++(this repo) | **84.68**| **85.42**| 85.92| **88.39**| 79.73| **91.86**| 75.37| **92.95**| **88.56**| **85.72**| **97.00**| 72.94| **96.88**| **84.52**| **64.38**| 79.39| **85.91**|
mcIoU: mean per-class pIoU
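As a rough illustration of the metric, the sketch below computes a part-averaged IoU per shape, averages it within each category, and then averages over categories. The function names, the `(category, pred, gt)` input layout and the toy labels are assumptions for illustration; this is not the evaluation script behind the table. Counting a part that is absent from both prediction and ground truth as IoU 1 follows the convention used in this repo's `main.py`.

```python
import numpy as np


def shape_iou(pred, gt, parts):
    """Mean IoU over the given part ids for one shape; empty parts count as 1."""
    ious = []
    for part in parts:
        inter = np.sum((pred == part) & (gt == part))
        union = np.sum((pred == part) | (gt == part))
        ious.append(1.0 if union == 0 else inter / union)
    return float(np.mean(ious))


def mean_class_iou(samples, parts_per_category):
    """samples: iterable of (category, pred, gt) tuples (hypothetical layout)."""
    per_category = {}
    for category, pred, gt in samples:
        iou = shape_iou(pred, gt, parts_per_category[category])
        per_category.setdefault(category, []).append(iou)
    # Average within each category first, then across categories.
    return float(np.mean([np.mean(v) for v in per_category.values()]))


# Toy example with two categories and made-up labels.
samples = [
    ('Mug', np.array([0, 0, 1, 1]), np.array([0, 0, 1, 1])),
    ('Cap', np.array([2, 2, 3, 3]), np.array([2, 3, 3, 3])),
]
parts = {'Mug': [0, 1], 'Cap': [2, 3]}
print(mean_class_iou(samples, parts))
```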
# Requirements
- [PyTorch](https://pytorch.org)
- [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric)
- [Open3D](https://github.com/intel-isl/Open3D) (optional, for visualization of segmentation results)
## Quickly install pytorch_geometric and Open3D with Anaconda
```
$ pip install --verbose --no-cache-dir torch-scatter
$ pip install --verbose --no-cache-dir torch-sparse
$ pip install --verbose --no-cache-dir torch-cluster
$ pip install --verbose --no-cache-dir torch-spline-conv  # optional
$ pip install torch-geometric
```
```
# optional
conda install -c open3d-admin open3d
```
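After installing, a quick way to see which of the packages are importable is sketched below. It only uses the standard library; the helper name and the split into required and optional module names are assumptions based on the requirements listed above.

```python
import importlib.util


def missing_packages(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names assumed from the Requirements section (open3d is optional).
required = ['torch', 'torch_geometric']
optional = ['open3d']

print('missing required:', missing_packages(required))
print('missing optional:', missing_packages(optional))
```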
# Usage
Training
```
python main.py
```
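`main.py` looks for data under `./data` by default, and the loader in `dataset/shapenet.py` reads `raw/train` and `raw/test` beneath that root: one folder per point cloud, each holding `.dat` files whose first line is a header starting with the point count, and whose label is the second-to-last `_`-separated token of the file name. The sketch below writes a tiny synthetic dataset in that layout for a smoke test. The file-name pattern `part_<Label>_<index>.dat`, the helper name and the random values are assumptions, not a specification of the author's data.

```python
import os
import random
import tempfile

# Category names taken from CustomShapeNet.categories in dataset/shapenet.py.
LABELS = ['Box', 'Cone', 'Cylinder', 'Sphere']


def write_dummy_dataset(root, clouds_per_split=2, points_per_file=20, seed=0):
    """Write random .dat files under <root>/raw/{train,test}/<cloud>/ (hypothetical helper)."""
    rng = random.Random(seed)
    written = []
    for split in ('train', 'test'):
        for cloud_idx in range(clouds_per_split):
            cloud_dir = os.path.join(root, 'raw', split, f'cloud{cloud_idx}')
            os.makedirs(cloud_dir, exist_ok=True)
            for file_idx, label in enumerate(LABELS):
                path = os.path.join(cloud_dir, f'part_{label}_{file_idx}.dat')
                with open(path, 'w') as f:
                    # Header line: the loader skips the file if this count is 0.
                    f.write(f'{points_per_file}\n')
                    for _ in range(points_per_file):
                        # Six space-separated values per row; only the first three are used as pos.
                        f.write(' '.join(f'{rng.random():.6f}' for _ in range(6)) + '\n')
                written.append(path)
    return written


root = tempfile.mkdtemp()
files = write_dummy_dataset(root)
print(len(files), 'files written under', root)
```

Passing that root via `--dataset` (and `--model ''` to skip loading a checkpoint) should then, in principle, give the training script something to process.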
Show segmentation result
```
python vis/show_seg_res.py
```
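The viewer colors each point by its predicted part label before handing the cloud to Open3D. A minimal NumPy-only sketch of that label-to-color lookup follows; the three-entry palette and the function name are illustrative assumptions (the repo's own table in `vis/view.py` has twelve colors).

```python
import numpy as np

# Small illustrative palette: one RGB row (values in [0, 1]) per label id.
PALETTE = np.array([
    [0.50, 0.54, 0.53],
    [0.89, 0.15, 0.21],
    [0.64, 0.58, 0.50],
])


def labels_to_colors(labels, palette=PALETTE):
    """Map integer labels of shape (n,) to RGB colors of shape (n, 3)."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.size and (labels.min() < 0 or labels.max() >= len(palette)):
        raise ValueError('label outside palette range')
    return palette[labels]  # integer-array indexing selects one row per label


colors = labels_to_colors([0, 2, 1, 0])
print(colors.shape)  # (4, 3)
```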
# Sample segmentation result
![segmentation_result](figs/segmentation_result.png)
# Links
- [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) by fxia22. This repo's training code borrows heavily from fxia22's repo.
- Official [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2) tensorflow implementations
- [PointNet++ classification example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet%2B%2B.py) of pytorch_geometric library

158
dataset/shapenet.py Normal file

@ -0,0 +1,158 @@
import glob
import os
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm


class CustomShapeNet(InMemoryDataset):

    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}

    def __init__(self, root, train=True, transform=None, pre_filter=None, pre_transform=None, **kwargs):
        # Stored for _load_dataset(); set before super().__init__(), which may call process().
        self.train = train
        super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)
        print("Initialized")

    @property
    def raw_file_names(self):
        # Maybe add more data like validation sets
        return ['train', 'test']

    @property
    def processed_file_names(self):
        return [f'{x}.pt' for x in self.raw_file_names]

    def download(self):
        # Nothing is downloaded; this only checks that raw data folders are present.
        dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
        print(f'{dir_count} folders have been found....')
        if dir_count:
            return dir_count
        raise IOError("No raw pointclouds have been found.")

    @property
    def num_classes(self):
        return len(self.categories)

    def _load_dataset(self):
        data, slices = None, None
        # processed_dir already contains the dataset root, so root is not joined again.
        filepath = os.path.join(self.processed_dir, f'{"train" if self.train else "test"}.pt')
        while True:
            try:
                data, slices = torch.load(filepath)
                print('Dataset Loaded')
                break
            except FileNotFoundError:
                self.process()
                continue
        return data, slices

    def process(self, delimiter=' '):
        # idx = self.categories[self.category]
        # paths = [osp.join(path, idx) for path in self.raw_paths]
        datasets = defaultdict(list)
        for idx, setting in enumerate(self.raw_file_names):
            for pointcloud in tqdm(os.scandir(os.path.join(self.raw_dir, setting))):
                if not pointcloud.is_dir():
                    continue
                for element in glob.glob(os.path.join(pointcloud.path, '*.dat')):
                    if os.path.split(element)[-1] == 'pc.dat':
                        continue
                    # Assign training data to the data container.
                    # Following the original logic:
                    #   y should be the label;
                    #   pos should be the six-dimensional vector (its pos, not points!):
                    #   x, y, z, x_rot, y_rot, z_rot
                    y_raw = os.path.splitext(element)[0].split('_')[-2]
                    with open(element, 'r') as f:
                        headers = next(f)
                        # Skip files without usable nodes (header says 0).
                        if not int(headers.rstrip().split(delimiter)[0]):
                            continue
                        # Parse all non-empty rows; NaN markers are replaced by 0.
                        src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
                                for x in line.rstrip().split(delimiter)]
                               for line in f if line.strip()]
                    points = torch.tensor(src).squeeze()
                    if points.dim() < 2:
                        continue
                    # pos = points[:, :3]
                    # norm = points[:, 3:]
                    y_all = [self.categories[y_raw]] * points.shape[0]
                    y = torch.as_tensor(y_all, dtype=torch.int)
                    # points = torch.as_tensor(points, dtype=torch.float)
                    # norm = torch.as_tensor(norm, dtype=torch.float)
                    data = Data(y=y, pos=points[:, :3])
                    # , points=points, norm=points[:3], )
                    # Usual torch_geometric pattern: drop samples rejected by pre_filter,
                    # then apply pre_transform (the original raised NotImplementedError here).
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    datasets[setting].append(data)
            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[setting]), self.processed_paths[idx])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
    """
    def __init__(self, root_dir, train=True, transform=None, npoints=1024):
        super(ShapeNetPartSegDataset, self).__init__()
        self.npoints = npoints
        self.dataset = CustomShapeNet(root=root_dir, train=train, transform=transform)

    def __getitem__(self, index):
        data = self.dataset[index]
        points, labels = data.pos, data.y
        # Resample to fixed number of points
        try:
            choice = np.random.choice(points.shape[0], self.npoints, replace=True)
        except ValueError:
            # Raised for an empty cloud; fall back to an empty selection.
            choice = []
        points, labels = points[choice, :], labels[choice]
        labels -= 1 if self.num_classes() in labels else 0   # Map label from [1, C] to [0, C-1]
        sample = {
            'points': points,  # torch.Tensor (n, 3)
            'labels': labels   # torch.Tensor (n,)
        }
        return sample

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes

200
main.py Normal file

@ -0,0 +1,200 @@
"""
Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
"""
import os, sys
import random
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import autograd
import torch.backends.cudnn as cudnn
from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import time
fs_root = os.path.splitdrive(sys.executable)[0]
# Argument parser
parser = argparse.ArgumentParser()
default_data_dir = os.path.join(os.getcwd(), 'data')
parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
parser.add_argument('--npoints', type=int, default=50, help='resample points number')
parser.add_argument('--model', type=str, default='checkpoint/seg_model_custom_24.pth', help="checkpoint to load before training ('' to start from scratch)")
parser.add_argument('--nepoch', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
opt = parser.parse_args()
print(opt)
# Random seed
opt.manual_seed = 123
print('Random seed: ', opt.manual_seed)
random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)
if __name__ == '__main__':
    # Dataset and transform
    print('Construct dataset ..')
    if True:
        rot_max_angle = 15
        trans_max_distance = 0.01
        RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
                                   GT.RandomRotate(rot_max_angle, 1),
                                   GT.RandomRotate(rot_max_angle, 2)]
                                  )
        TransTransform = GT.RandomTranslate(trans_max_distance)
        train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
    test_transform = GT.Compose([GT.NormalizeScale(), ])

    dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    num_classes = dataset.num_classes()
    print('dataset size: ', len(dataset))
    print('test_dataset size: ', len(test_dataset))
    print('num_classes: ', num_classes)

    # Create the output folder. An existing folder is fine; other errors
    # (e.g. missing permissions) are raised instead of being silently ignored.
    os.makedirs(opt.outf, exist_ok=True)

    ## Model, criterion and optimizer
    print('Construct model ..')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float
    print('cudnn.enabled: ', torch.backends.cudnn.enabled)
    net = PointNet2PartSegmentNet(num_classes)
    if opt.model != '':
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        net.load_state_dict(torch.load(opt.model, map_location=device))
    net = net.to(device, dtype)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(net.parameters())

    if True:
        ## Train
        print('Training ..')
        blue = lambda x: '\033[94m' + x + '\033[0m'
        num_batch = len(dataset) // opt.batch_size
        test_per_batches = opt.test_per_batches
        print('number of epochs: ', opt.nepoch)
        print('number of batches per epoch: ', num_batch)
        print('run test per batches: ', test_per_batches)
        for epoch in range(opt.nepoch):
            print('Epoch {}, total epochs {}'.format(epoch+1, opt.nepoch))
            net.train()
            for batch_idx, sample in enumerate(dataloader):
                # points: (batch_size, n, 3)
                # labels: (batch_size, n)
                points, labels = sample['points'], sample['labels']
                points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
                points, labels = points.to(device, dtype), labels.to(device, torch.long)
                optimizer.zero_grad()
                pred = net(points)  # (batch_size, n, num_classes)
                pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
                target = labels.view(-1, 1)[:, 0]
                loss = F.nll_loss(pred, target)
                loss.backward()
                optimizer.step()
                ##
                pred_label = pred.detach().max(1)[1]
                correct = pred_label.eq(target.detach()).cpu().sum()
                total = pred_label.shape[0]
                print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
                ##
                if batch_idx % test_per_batches == 0:
                    print('Run a test batch')
                    net.eval()
                    with torch.no_grad():
                        # Draw one test batch without overwriting the training batch_idx.
                        sample = next(iter(test_dataloader))
                        points, labels = sample['points'], sample['labels']
                        points = points.transpose(1, 2).contiguous()
                        points, labels = points.to(device, dtype), labels.to(device, torch.long)
                        pred = net(points)
                        pred = pred.view(-1, num_classes)
                        target = labels.view(-1, 1)[:, 0]
                        target += 1 if -1 in target else 0
                        loss = F.nll_loss(pred, target)
                        pred_label = pred.detach().max(1)[1]
                        correct = pred_label.eq(target.detach()).cpu().sum()
                        total = pred_label.shape[0]
                        print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
                    # Back to training mode
                    net.train()
            torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')

    ## Benchmark mIoU
    net.eval()
    shape_ious = []
    with torch.no_grad():
        for batch_idx, sample in enumerate(test_dataloader):
            points, labels = sample['points'], sample['labels']
            points = points.transpose(1, 2).contiguous()
            points = points.to(device, dtype)
            # start_t = time.time()
            pred = net(points)  # (batch_size, n, num_classes)
            # print('batch inference forward time used: {} ms'.format(time.time() - start_t))
            pred_label = pred.max(2)[1]
            pred_label = pred_label.cpu().numpy()
            target_label = labels.numpy()
            batch_size = target_label.shape[0]
            for shape_idx in range(batch_size):
                parts = range(num_classes)  # np.unique(target_label[shape_idx])
                part_ious = []
                for part in parts:
                    I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
                    U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
                    # A part absent from both prediction and target counts as IoU 1.
                    iou = 1 if U == 0 else float(I) / U
                    part_ious.append(iou)
                shape_ious.append(np.mean(part_ious))
    print(f'mIoU for our custom dataset: {np.mean(shape_ious)}')
    print('Done.')

286
model/pointnet2_part_seg.py Normal file

@ -0,0 +1,286 @@
import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
from torch_geometric.nn import PointConv, fps, radius, knn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max


class PointNet2SAModule(torch.nn.Module):
    def __init__(self, sample_ratio, radius, max_num_neighbors, mlp):
        super(PointNet2SAModule, self).__init__()
        self.sample_ratio = sample_ratio
        self.radius = radius
        self.max_num_neighbors = max_num_neighbors
        self.point_conv = PointConv(mlp)

    def forward(self, data):
        x, pos, batch = data
        # Sample
        idx = fps(pos, batch, ratio=self.sample_ratio)
        # Group (build graph)
        row, col = radius(pos, pos[idx], self.radius, batch, batch[idx], max_num_neighbors=self.max_num_neighbors)
        edge_index = torch.stack([col, row], dim=0)
        # Apply PointNet
        x1 = self.point_conv(x, (pos, pos[idx]), edge_index)
        pos1, batch1 = pos[idx], batch[idx]
        return x1, pos1, batch1


class PointNet2GlobalSAModule(torch.nn.Module):
    '''
    One group with all input points; can be viewed as a simple PointNet module.
    It returns a single output point per cloud (placed at the origin).
    '''
    def __init__(self, mlp):
        super(PointNet2GlobalSAModule, self).__init__()
        self.mlp = mlp

    def forward(self, data):
        x, pos, batch = data
        # Without input features, fall back to the positions alone.
        x = pos if x is None else torch.cat([x, pos], dim=1)
        x1 = self.mlp(x)
        x1 = scatter_max(x1, batch, dim=0)[0]  # (batch_size, C1)
        batch_size = x1.shape[0]
        pos1 = x1.new_zeros((batch_size, 3))  # set the output point as origin
        batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
        return x1, pos1, batch1


class PointConvFP(MessagePassing):
    '''
    Core layer of the feature propagation module.
    '''
    def __init__(self, mlp=None):
        super(PointConvFP, self).__init__('add', 'source_to_target')
        self.mlp = mlp
        self.aggr = 'add'
        self.flow = 'source_to_target'
        self.reset_parameters()

    def reset_parameters(self):
        reset(self.mlp)

    def forward(self, x, pos, edge_index):
        r"""
        Args:
            x (tuple), (tensor, tensor) or (tensor, NoneType)
            pos (tuple): The node position matrix. Either given as
                tensor for use in general message passing or as tuple for use
                in message passing in bipartite graphs.
            edge_index (LongTensor): The edge indices.
        """
        # Do not pass (tensor, None) directly into propagate(), since it will check each item's size() inside.
        x_tmp = x[0] if x[1] is None else x
        aggr_out = self.propagate(edge_index, x=x_tmp, pos=pos)
        #
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        x_target, pos_target = x[i], pos[i]
        add = [pos_target, ] if x_target is None else [x_target, pos_target]
        aggr_out = torch.cat([aggr_out, *add], dim=1)
        if self.mlp is not None:
            aggr_out = self.mlp(aggr_out)
        return aggr_out

    def message(self, x_j, pos_j, pos_i, edge_index):
        '''
        x_j: (E, in_channels)
        pos_j: (E, 3)
        pos_i: (E, 3)
        '''
        # Inverse-distance weights, normalized per target node.
        dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
        dist = dist.clamp(min=1e-10)  # avoid division by zero
        weight = 1.0 / dist  # (E,)
        row, col = edge_index
        index = col
        num_nodes = maybe_num_nodes(index, None)
        wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16  # (E,)
        weight /= wsum
        return weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out


class PointNet2FPModule(torch.nn.Module):
    def __init__(self, knn_num, mlp):
        super(PointNet2FPModule, self).__init__()
        self.knn_num = knn_num
        self.point_conv = PointConvFP(mlp)

    def forward(self, in_layer_data, skip_layer_data):
        in_x, in_pos, in_batch = in_layer_data
        skip_x, skip_pos, skip_batch = skip_layer_data
        row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
        edge_index = torch.stack([col, row], dim=0)
        x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
        pos1, batch1 = skip_pos, skip_batch
        return x1, pos1, batch1


def make_mlp(in_channels, mlp_channels, batch_norm=True):
    assert len(mlp_channels) >= 1
    layers = []
    for c in mlp_channels:
        layers += [Lin(in_channels, c)]
        if batch_norm:
            layers += [BatchNorm1d(c)]
        layers += [ReLU()]
        in_channels = c
    return Seq(*layers)


class PointNet2PartSegmentNet(torch.nn.Module):
    '''
    ref:
        - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
        - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
    '''
    def __init__(self, num_classes):
        super(PointNet2PartSegmentNet, self).__init__()
        self.num_classes = num_classes
        # SA1
        sa1_sample_ratio = 0.5
        sa1_radius = 0.2
        sa1_max_num_neighbours = 64
        sa1_mlp = make_mlp(3, [64, 64, 128])
        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
        # SA2
        sa2_sample_ratio = 0.25
        sa2_radius = 0.4
        sa2_max_num_neighbours = 64
        sa2_mlp = make_mlp(128+3, [128, 128, 256])
        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
        # SA3
        sa3_mlp = make_mlp(256+3, [256, 512, 1024])
        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
        ##
        knn_num = 3
        # FP3, reverse of sa3
        fp3_knn_num = 1  # After global sa module, there is only one point in point cloud
        fp3_mlp = make_mlp(1024+256+3, [256, 256])
        self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
        # FP2, reverse of sa2
        fp2_knn_num = knn_num
        fp2_mlp = make_mlp(256+128+3, [256, 128])
        self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
        # FP1, reverse of sa1
        fp1_knn_num = knn_num
        fp1_mlp = make_mlp(128+3, [128, 128, 128])
        self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
        self.fc1 = Lin(128, 128)
        self.dropout1 = Dropout(p=0.5)
        self.fc2 = Lin(128, self.num_classes)

    def forward(self, data):
        '''
        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
            - torch.Tensor: (batch_size, 3, num_points), as common batch input
            - torch_geometric.data.Data, as torch_geometric batch input:
                data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                    ~num_points means each sample can have different number of points/nodes
                data.pos: (batch_size * ~num_points, 3)
                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
                    identifiers for all nodes of all graphs/pointclouds in the batch. See
                    pytorch_geometric documentation for more information
        '''
        dense_input = isinstance(data, torch.Tensor)
        if dense_input:
            # Convert to torch_geometric.data.Data type
            data = data.transpose(1, 2).contiguous()
            batch_size, N, _ = data.shape  # (batch_size, num_points, 3)
            pos = data.view(batch_size*N, -1)
            # Point i of sample b gets batch id b: [0, ..., 0, 1, ..., 1, ...]
            batch = torch.arange(batch_size, device=pos.device, dtype=torch.long)
            batch = batch.view(-1, 1).repeat(1, N).view(-1)
            data = Data()
            data.pos, data.batch = pos, batch
        if not hasattr(data, 'x'):
            data.x = None
        data_in = data.x, data.pos, data.batch
        sa1_out = self.sa1_module(data_in)
        sa2_out = self.sa2_module(sa1_out)
        sa3_out = self.sa3_module(sa2_out)
        fp3_out = self.fp3_module(sa3_out, sa2_out)
        fp2_out = self.fp2_module(fp3_out, sa1_out)
        fp1_out = self.fp1_module(fp2_out, data_in)
        fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out
        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
        x = F.log_softmax(x, dim=-1)
        if dense_input:
            return x.view(batch_size, N, self.num_classes)
        else:
            return x, fp1_out_batch


if __name__ == '__main__':
    num_classes = 10
    net = PointNet2PartSegmentNet(num_classes)
    #
    print('Test dense input ..')
    data1 = torch.rand((2, 3, 1024))  # (batch_size, 3, num_points)
    print('data1: ', data1.shape)
    out1 = net(data1)
    print('out1: ', out1.shape)
    #
    print('Test torch_geometric.data.Data input ..')

    def make_data_batch():
        # batch_size = 2
        pos_num1 = 1000
        pos_num2 = 1024
        data_batch = Data()
        # data_batch.x = None
        data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
        data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
        return data_batch

    data2 = make_data_batch()
    # print('data.x: ', data.x)
    print('data2.pos: ', data2.pos.shape)
    print('data2.batch: ', data2.batch.shape)
    out2_x, out2_batch = net(data2)
    print('out2_x: ', out2_x.shape)
    print('out2_batch: ', out2_batch.shape)

133
vis/show_seg_res.py Normal file

@ -0,0 +1,133 @@
# Warning: importing open3d after other libraries may lead to a crash, so import it first here (view imports open3d)
from view import view_points_labels
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import torch
import numpy as np
import argparse
##
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='data', help='dataset path')
parser.add_argument('--npoints', type=int, default=50, help='resample points number')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_Airplane_24.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
opt = parser.parse_args()
print(opt)
if __name__ == '__main__':
    ## Load dataset
    print('Construct dataset ..')
    test_transform = GT.Compose([GT.NormalizeScale(), ])
    test_dataset = ShapeNetPartSegDataset(
        root_dir=opt.dataset,
        train=False,
        transform=test_transform,
        npoints=opt.npoints
    )
    num_classes = test_dataset.num_classes()
    print('test dataset size: ', len(test_dataset))
    print('num_classes: ', num_classes)

    # Load model
    print('Construct model ..')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float
    # net = PointNetPartSegmentNet(num_classes)
    net = PointNet2PartSegmentNet(num_classes)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(opt.model, map_location=device))
    net = net.to(device, dtype)
    net.eval()

    ##
    def eval_sample(net, sample):
        '''
        sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
        return: (pred_label, gt_label) with labels shape (n,)
        '''
        net.eval()
        with torch.no_grad():
            # points: (n, 3)
            points, gt_label = sample['points'], sample['labels']
            n = points.shape[0]
            points = points.view(1, n, 3)  # make a batch
            points = points.transpose(1, 2).contiguous()
            points = points.to(device, dtype)
            pred = net(points)  # (batch_size, n, num_classes)
            pred_label = pred.max(2)[1]
            pred_label = pred_label.view(-1).cpu()  # (n,)
            assert pred_label.shape == gt_label.shape
            return (pred_label, gt_label)

    def compute_mIoU(pred_label, gt_label):
        minl, maxl = np.min(gt_label), np.max(gt_label)
        ious = []
        for l in range(minl, maxl+1):
            I = np.sum(np.logical_and(pred_label == l, gt_label == l))
            U = np.sum(np.logical_or(pred_label == l, gt_label == l))
            if U == 0:
                iou = 1
            else:
                iou = float(I) / U
            ious.append(iou)
        return np.mean(ious)

    def label_diff(pred_label, gt_label):
        '''
        Assign 1 if different label, or 0 if same label
        '''
        diff = pred_label - gt_label
        diff_mask = (diff != 0)
        diff_label = np.zeros((pred_label.shape[0]), dtype=np.int32)
        diff_label[diff_mask] = 1
        return diff_label

    # Get one sample and eval
    sample = test_dataset[opt.sample_idx]
    print('Eval test sample ..')
    pred_label, gt_label = eval_sample(net, sample)
    print('Eval done ..')

    # Get sample result
    print('Compute mIoU ..')
    points = sample['points'].numpy()
    pred_labels = pred_label.numpy()
    gt_labels = gt_label.numpy()
    diff_labels = label_diff(pred_labels, gt_labels)
    print('mIoU: ', compute_mIoU(pred_labels, gt_labels))

    # View result
    # print('View gt labels ..')
    # view_points_labels(points, gt_labels)
    # print('View diff labels ..')
    # view_points_labels(points, diff_labels)
    print('View pred labels ..')
    view_points_labels(points, pred_labels)

63
vis/view.py Normal file

@ -0,0 +1,63 @@
import open3d as o3d
import numpy as np


def mini_color_table(index, norm=True):
    colors = [
        [0.5000, 0.5400, 0.5300], [0.8900, 0.1500, 0.2100], [0.6400, 0.5800, 0.5000],
        [1.0000, 0.3800, 0.0100], [1.0000, 0.6600, 0.1400], [0.4980, 1.0000, 0.0000],
        [0.4980, 1.0000, 0.8314], [0.9412, 0.9725, 1.0000], [0.5412, 0.1686, 0.8863],
        [0.5765, 0.4392, 0.8588], [0.3600, 0.1400, 0.4300], [0.5600, 0.3700, 0.6000],
    ]
    assert index >= 0 and index < len(colors)
    color = colors[index]
    if not norm:
        # Scale from [0, 1] to [0, 255]
        color = [c * 255 for c in color]
    return color


def view_points(points, colors=None):
    '''
    points: np.ndarray with shape (n, 3)
    colors: [r, g, b] or np.array with shape (n, 3)
    '''
    # Namespaced Open3D API; the top-level aliases (o3d.PointCloud, o3d.Vector3dVector,
    # o3d.draw_geometries) used originally are not available in newer Open3D releases.
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if colors is not None:
        if isinstance(colors, np.ndarray):
            cloud.colors = o3d.utility.Vector3dVector(colors)
        else:
            cloud.paint_uniform_color(colors)
    o3d.visualization.draw_geometries([cloud])


def label2color(labels):
    '''
    labels: np.ndarray with shape (n, )
    colors(return): np.ndarray with shape (n, 3)
    '''
    num = labels.shape[0]
    colors = np.zeros((num, 3))
    minl, maxl = np.min(labels), np.max(labels)
    for l in range(minl, maxl + 1):
        colors[labels == l, :] = mini_color_table(l)
    return colors


def view_points_labels(points, labels):
    '''
    Assign points with colors by labels and view colored points.
    points: np.ndarray with shape (n, 3)
    labels: np.ndarray with shape (n, ), dtype=np.int32
    '''
    assert points.shape[0] == labels.shape[0]
    colors = label2color(labels)
    view_points(points, colors)