initial commit
This commit is contained in:
commit
0159363642
129
.gitignore
vendored
Normal file
129
.gitignore
vendored
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Created by .ignore support plugin (hsz.mobi)
|
||||||
|
### Python template
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
/data/
|
||||||
|
/checkpoint/
|
||||||
|
/shapenet/
|
2
.idea/.gitignore
generated
vendored
Normal file
2
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/workspace.xml
|
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="JavaScriptSettings">
|
||||||
|
<option name="languageLevel" value="ES6" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (torch)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" filepath="$PROJECT_DIR$/.idea/pointnet2-pytorch.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
14
.idea/pointnet2-pytorch.iml
generated
Normal file
14
.idea/pointnet2-pytorch.iml
generated
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/data" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/net" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.7 (torch)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="TestRunnerService">
|
||||||
|
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||||
|
</component>
|
||||||
|
</module>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
54
README.md
Normal file
54
README.md
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Pointnet++ Part segmentation
|
||||||
|
This repo is implementation for [PointNet++](https://arxiv.org/abs/1706.02413) part segmentation model based on [PyTorch](https://pytorch.org) and [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric). It can achieve comparable or better performance even compared with [PointCNN](https://arxiv.org/abs/1801.07791) on Shapenet dataset.
|
||||||
|
|
||||||
|
**The model has been mergered into [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) as a point cloud segmentation [example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py), you can try it.**
|
||||||
|
|
||||||
|
# Performance
|
||||||
|
Segmentation on [A subset of shapenet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
|
||||||
|
|
||||||
|
| Method | mcIoU|Airplane|Bag|Cap|Car|Chair|Earphone|Guitar|Knife|Lamp|Laptop|Motorbike|Mug|Pistol|Rocket|Skateboard|Table
|
||||||
|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||||
|
| PointNet++ | 81.9| 82.4| 79.0| 87.7| 77.3 |90.8| 71.8| 91.0| 85.9| 83.7| 95.3| 71.6| 94.1| 81.3| 58.7| 76.4| 82.6|
|
||||||
|
| PointCNN | 84.6| 84.11| **86.47**| 86.04| **80.83**| 90.62| **79.70**| 92.32| 88.44| 85.31| 96.11| **77.20**| 95.28| 84.21| 64.23| **80.00**| 82.99|
|
||||||
|
| PointNet++(this repo) | **84.68**| **85.42**| 85.92| **88.39**| 79.73| **91.86**| 75.37| **92.95**| **88.56**| **85.72**| **97.00**| 72.94| **96.88**| **84.52**| **64.38**| 79.39| **85.91**|
|
||||||
|
|
||||||
|
mcIOU: mean per-class pIoU
|
||||||
|
|
||||||
|
# Requirements
|
||||||
|
- [PyTorch](https://pytorch.org)
|
||||||
|
- [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric)
|
||||||
|
- [Open3D](https://github.com/intel-isl/Open3D)(optional, for visualization of segmentation result)
|
||||||
|
|
||||||
|
## Quickly install pytorch_geometric and Open3D with Anaconda
|
||||||
|
```
|
||||||
|
$ pip install --verbose --no-cache-dir torch-scatter
|
||||||
|
$ pip install --verbose --no-cache-dir torch-sparse
|
||||||
|
$ pip install --verbose --no-cache-dir torch-cluster
|
||||||
|
$ pip install --verbose --no-cache-dir torch-spline-conv (optional)
|
||||||
|
$ pip install torch-geometric
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
# optional
|
||||||
|
conda install -c open3d-admin open3d
|
||||||
|
```
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
Training
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Show segmentation result
|
||||||
|
```
|
||||||
|
python vis/show_seg_res.py
|
||||||
|
```
|
||||||
|
|
||||||
|
# Sample segmentation result
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
# Links
|
||||||
|
- [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) by fxia22. This repo's tranining code is heavily borrowed from fxia22's repo.
|
||||||
|
- Official [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2) tensorflow implementations
|
||||||
|
- [PointNet++ classification example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet%2B%2B.py) of pytorch_geometric library
|
158
dataset/shapenet.py
Normal file
158
dataset/shapenet.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch_geometric.datasets import ShapeNet
|
||||||
|
|
||||||
|
from itertools import repeat, product
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import (Data, InMemoryDataset, download_url,
|
||||||
|
extract_zip)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch_geometric.read import read_txt_array
|
||||||
|
from torch_geometric.datasets import ShapeNet
|
||||||
|
|
||||||
|
from torch_geometric.read import parse_txt_array
|
||||||
|
|
||||||
|
|
||||||
|
class CustomShapeNet(InMemoryDataset):
|
||||||
|
|
||||||
|
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
|
||||||
|
|
||||||
|
def __init__(self, root, train=True, transform=None, pre_filter=None, pre_transform=None, **kwargs):
|
||||||
|
super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
|
||||||
|
path = self.processed_paths[0] if train else self.processed_paths[1]
|
||||||
|
self.data, self.slices = torch.load(path)
|
||||||
|
print("Initialized")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_file_names(self):
|
||||||
|
# Maybe add more data like validation sets
|
||||||
|
return ['train', 'test']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processed_file_names(self):
|
||||||
|
return [f'{x}.pt' for x in self.raw_file_names]
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
||||||
|
print(f'{dir_count} folders have been found....')
|
||||||
|
if dir_count:
|
||||||
|
return dir_count
|
||||||
|
raise IOError("No raw pointclouds have been found.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return len(self.categories)
|
||||||
|
|
||||||
|
def _load_dataset(self):
|
||||||
|
data, slices = None, None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
filepath = os.path.join(self.root, self.processed_dir, f'{"train" if self.train else "test"}.pt')
|
||||||
|
data, slices = torch.load(filepath)
|
||||||
|
print('Dataset Loaded')
|
||||||
|
break
|
||||||
|
except FileNotFoundError:
|
||||||
|
self.process()
|
||||||
|
continue
|
||||||
|
return data, slices
|
||||||
|
|
||||||
|
def process(self, delimiter=' '):
|
||||||
|
# idx = self.categories[self.category]
|
||||||
|
# paths = [osp.join(path, idx) for path in self.raw_paths]
|
||||||
|
|
||||||
|
datasets = defaultdict(list)
|
||||||
|
for idx, setting in enumerate(self.raw_file_names):
|
||||||
|
for pointcloud in tqdm(os.scandir(os.path.join(self.raw_dir, setting))):
|
||||||
|
if not os.path.isdir(pointcloud):
|
||||||
|
continue
|
||||||
|
for element in glob.glob(os.path.join(pointcloud.path, '*.dat')):
|
||||||
|
if os.path.split(element)[-1] not in ['pc.dat']:
|
||||||
|
# Assign training data to the data container
|
||||||
|
# Following the original logic;
|
||||||
|
# y should be the label;
|
||||||
|
# pos should be the six dimensional vector describing: !its pos not points!!
|
||||||
|
# x,y,z,x_rot,y_rot,z_rot
|
||||||
|
y_raw = os.path.splitext(element)[0].split('_')[-2]
|
||||||
|
with open(element,'r') as f:
|
||||||
|
headers = f.__next__()
|
||||||
|
# Check if there are no useable nodes in this file, header says 0.
|
||||||
|
if not int(headers.rstrip().split(delimiter)[0]):
|
||||||
|
continue
|
||||||
|
# Get the y - Label
|
||||||
|
|
||||||
|
# Iterate over all rows
|
||||||
|
src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
|
||||||
|
for x in line.rstrip().split(delimiter)[None:None]] for line in f if line != '']
|
||||||
|
points = torch.tensor(src, dtype=None).squeeze()
|
||||||
|
if not len(points.shape) > 1:
|
||||||
|
continue
|
||||||
|
# pos = points[:, :3]
|
||||||
|
# norm = points[:, 3:]
|
||||||
|
y_all = [self.categories[y_raw]] * points.shape[0]
|
||||||
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
|
# points = torch.as_tensor(points, dtype=torch.float)
|
||||||
|
# norm = torch.as_tensor(norm, dtype=torch.float)
|
||||||
|
data = Data(y=y, pos=points[:, :3])
|
||||||
|
# , points=points, norm=points[:3], )
|
||||||
|
# ToDo: ANy filter to apply? Then do it here.
|
||||||
|
if self.pre_filter is not None and not self.pre_filter(data):
|
||||||
|
data = self.pre_filter(data)
|
||||||
|
raise NotImplementedError
|
||||||
|
# ToDo: ANy transformation to apply? Then do it here.
|
||||||
|
if self.pre_transform is not None:
|
||||||
|
data = self.pre_transform(data)
|
||||||
|
raise NotImplementedError
|
||||||
|
datasets[setting].append(data)
|
||||||
|
|
||||||
|
os.makedirs(self.processed_dir, exist_ok=True)
|
||||||
|
torch.save(self.collate(datasets[setting]), self.processed_paths[idx])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({len(self)})'
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeNetPartSegDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Resample raw point cloud to fixed number of points.
|
||||||
|
Map raw label from range [1, N] to [0, N-1].
|
||||||
|
"""
|
||||||
|
def __init__(self, root_dir, train=True, transform=None, npoints=1024):
|
||||||
|
super(ShapeNetPartSegDataset, self).__init__()
|
||||||
|
self.npoints = npoints
|
||||||
|
self.dataset = CustomShapeNet(root=root_dir, train=train, transform=transform)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data = self.dataset[index]
|
||||||
|
points, labels = data.pos, data.y
|
||||||
|
|
||||||
|
# Resample to fixed number of points
|
||||||
|
try:
|
||||||
|
choice = np.random.choice(points.shape[0], self.npoints, replace=True)
|
||||||
|
except ValueError:
|
||||||
|
choice = []
|
||||||
|
|
||||||
|
points, labels = points[choice, :], labels[choice]
|
||||||
|
|
||||||
|
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
|
sample = {
|
||||||
|
'points': points, # torch.Tensor (n, 3)
|
||||||
|
'labels': labels # torch.Tensor (n,)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.dataset)
|
||||||
|
|
||||||
|
def num_classes(self):
|
||||||
|
return self.dataset.num_classes
|
200
main.py
Normal file
200
main.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
"""
|
||||||
|
Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
|
||||||
|
"""
|
||||||
|
import os, sys
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import autograd
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
|
||||||
|
from dataset.shapenet import ShapeNetPartSegDataset
|
||||||
|
from model.pointnet2_part_seg import PointNet2PartSegmentNet
|
||||||
|
import torch_geometric.transforms as GT
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
fs_root = os.path.splitdrive(sys.executable)[0]
|
||||||
|
|
||||||
|
# Argument parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
default_data_dir = os.path.join(os.getcwd(), 'data')
|
||||||
|
parser.add_argument('--dataset', type=str, default=default_data_dir, help='dataset path')
|
||||||
|
parser.add_argument('--npoints', type=int, default=50, help='resample points number')
|
||||||
|
parser.add_argument('--model', type=str, default='checkpoint//seg_model_custom_24.pth', help='model path')
|
||||||
|
parser.add_argument('--nepoch', type=int, default=10, help='number of epochs to train for')
|
||||||
|
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
|
||||||
|
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||||
|
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
print(opt)
|
||||||
|
|
||||||
|
# Random seed
|
||||||
|
opt.manual_seed = 123
|
||||||
|
print('Random seed: ', opt.manual_seed)
|
||||||
|
random.seed(opt.manual_seed)
|
||||||
|
np.random.seed(opt.manual_seed)
|
||||||
|
torch.manual_seed(opt.manual_seed)
|
||||||
|
torch.cuda.manual_seed(opt.manual_seed)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Dataset and transform
|
||||||
|
print('Construct dataset ..')
|
||||||
|
if True:
|
||||||
|
rot_max_angle = 15
|
||||||
|
trans_max_distance = 0.01
|
||||||
|
|
||||||
|
RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0),
|
||||||
|
GT.RandomRotate(rot_max_angle, 1),
|
||||||
|
GT.RandomRotate(rot_max_angle, 2)]
|
||||||
|
)
|
||||||
|
TransTransform = GT.RandomTranslate(trans_max_distance)
|
||||||
|
|
||||||
|
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
|
||||||
|
test_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||||
|
|
||||||
|
dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=True, transform=train_transform, npoints=opt.npoints)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
|
||||||
|
|
||||||
|
test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, train=False, transform=test_transform, npoints=opt.npoints)
|
||||||
|
test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
|
||||||
|
|
||||||
|
num_classes = dataset.num_classes()
|
||||||
|
|
||||||
|
print('dataset size: ', len(dataset))
|
||||||
|
print('test_dataset size: ', len(test_dataset))
|
||||||
|
print('num_classes: ', num_classes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.mkdir(opt.outf)
|
||||||
|
except OSError:
|
||||||
|
#FIXME: Why is this just a pass? What about missing permissions? LOL
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
## Model, criterion and optimizer
|
||||||
|
print('Construct model ..')
|
||||||
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
dtype = torch.float
|
||||||
|
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
|
||||||
|
|
||||||
|
|
||||||
|
net = PointNet2PartSegmentNet(num_classes)
|
||||||
|
|
||||||
|
if opt.model != '':
|
||||||
|
net.load_state_dict(torch.load(opt.model))
|
||||||
|
net = net.to(device, dtype)
|
||||||
|
|
||||||
|
criterion = nn.NLLLoss()
|
||||||
|
optimizer = optim.Adam(net.parameters())
|
||||||
|
if True:
|
||||||
|
## Train
|
||||||
|
print('Training ..')
|
||||||
|
blue = lambda x: '\033[94m' + x + '\033[0m'
|
||||||
|
num_batch = len(dataset) // opt.batch_size
|
||||||
|
test_per_batches = opt.test_per_batches
|
||||||
|
|
||||||
|
print('number of epoches: ', opt.nepoch)
|
||||||
|
print('number of batches per epoch: ', num_batch)
|
||||||
|
print('run test per batches: ', test_per_batches)
|
||||||
|
|
||||||
|
for epoch in range(opt.nepoch):
|
||||||
|
print('Epoch {}, total epoches {}'.format(epoch+1, opt.nepoch))
|
||||||
|
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
for batch_idx, sample in enumerate(dataloader):
|
||||||
|
# points: (batch_size, n, 3)
|
||||||
|
# labels: (batch_size, n)
|
||||||
|
points, labels = sample['points'], sample['labels']
|
||||||
|
points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
|
||||||
|
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
pred = net(points) # (batch_size, n, num_classes)
|
||||||
|
pred = pred.view(-1, num_classes) # (batch_size * n, num_classes)
|
||||||
|
target = labels.view(-1, 1)[:, 0]
|
||||||
|
|
||||||
|
loss = F.nll_loss(pred, target)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
##
|
||||||
|
pred_label = pred.detach().max(1)[1]
|
||||||
|
correct = pred_label.eq(target.detach()).cpu().sum()
|
||||||
|
total = pred_label.shape[0]
|
||||||
|
|
||||||
|
print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))
|
||||||
|
|
||||||
|
##
|
||||||
|
if batch_idx % test_per_batches == 0:
|
||||||
|
print('Run a test batch')
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_idx, sample = next(enumerate(test_dataloader))
|
||||||
|
|
||||||
|
points, labels = sample['points'], sample['labels']
|
||||||
|
points = points.transpose(1, 2).contiguous()
|
||||||
|
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
||||||
|
|
||||||
|
pred = net(points)
|
||||||
|
pred = pred.view(-1, num_classes)
|
||||||
|
target = labels.view(-1, 1)[:, 0]
|
||||||
|
|
||||||
|
target += 1 if -1 in target else 0
|
||||||
|
loss = F.nll_loss(pred, target)
|
||||||
|
|
||||||
|
pred_label = pred.detach().max(1)[1]
|
||||||
|
correct = pred_label.eq(target.detach()).cpu().sum()
|
||||||
|
total = pred_label.shape[0]
|
||||||
|
print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))
|
||||||
|
|
||||||
|
# Back to training mode
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
torch.save(net.state_dict(), f'{opt.outf}/seg_model_custom_{epoch}.pth')
|
||||||
|
|
||||||
|
|
||||||
|
## Benchmarm mIOU
|
||||||
|
net.eval()
|
||||||
|
shape_ious = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, sample in enumerate(test_dataloader):
|
||||||
|
points, labels = sample['points'], sample['labels']
|
||||||
|
points = points.transpose(1, 2).contiguous()
|
||||||
|
points = points.to(device, dtype)
|
||||||
|
|
||||||
|
# start_t = time.time()
|
||||||
|
pred = net(points) # (batch_size, n, num_classes)
|
||||||
|
# print('batch inference forward time used: {} ms'.format(time.time() - start_t))
|
||||||
|
|
||||||
|
pred_label = pred.max(2)[1]
|
||||||
|
pred_label = pred_label.cpu().numpy()
|
||||||
|
target_label = labels.numpy()
|
||||||
|
|
||||||
|
batch_size = target_label.shape[0]
|
||||||
|
for shape_idx in range(batch_size):
|
||||||
|
parts = range(num_classes) # np.unique(target_label[shape_idx])
|
||||||
|
part_ious = []
|
||||||
|
for part in parts:
|
||||||
|
I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
|
||||||
|
U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
|
||||||
|
iou = 1 if U == 0 else float(I) / U
|
||||||
|
part_ious.append(iou)
|
||||||
|
shape_ious.append(np.mean(part_ious))
|
||||||
|
|
||||||
|
print(f'mIOU for us Custom: {np.mean(shape_ious)}')
|
||||||
|
|
||||||
|
print('Done.')
|
286
model/pointnet2_part_seg.py
Normal file
286
model/pointnet2_part_seg.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
|
||||||
|
from torch_geometric.nn import PointConv, fps, radius, knn
|
||||||
|
from torch_geometric.nn.conv import MessagePassing
|
||||||
|
from torch_geometric.nn.inits import reset
|
||||||
|
from torch_geometric.utils.num_nodes import maybe_num_nodes
|
||||||
|
from torch_geometric.data.data import Data
|
||||||
|
from torch_scatter import scatter_add, scatter_max
|
||||||
|
|
||||||
|
|
||||||
|
class PointNet2SAModule(torch.nn.Module):
|
||||||
|
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||||
|
super(PointNet2SAModule, self).__init__()
|
||||||
|
self.sample_ratio = sample_radio
|
||||||
|
self.radius = radius
|
||||||
|
self.max_num_neighbors = max_num_neighbors
|
||||||
|
self.point_conv = PointConv(mlp)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
x, pos, batch = data
|
||||||
|
|
||||||
|
# Sample
|
||||||
|
idx = fps(pos, batch, ratio=self.sample_ratio)
|
||||||
|
|
||||||
|
# Group(Build graph)
|
||||||
|
row, col = radius(pos, pos[idx], self.radius, batch, batch[idx], max_num_neighbors=self.max_num_neighbors)
|
||||||
|
edge_index = torch.stack([col, row], dim=0)
|
||||||
|
|
||||||
|
# Apply pointnet
|
||||||
|
x1 = self.point_conv(x, (pos, pos[idx]), edge_index)
|
||||||
|
pos1, batch1 = pos[idx], batch[idx]
|
||||||
|
|
||||||
|
return x1, pos1, batch1
|
||||||
|
|
||||||
|
|
||||||
|
class PointNet2GlobalSAModule(torch.nn.Module):
|
||||||
|
'''
|
||||||
|
One group with all input points, can be viewed as a simple PointNet module.
|
||||||
|
It also return the only one output point(set as origin point).
|
||||||
|
'''
|
||||||
|
def __init__(self, mlp):
|
||||||
|
super(PointNet2GlobalSAModule, self).__init__()
|
||||||
|
self.mlp = mlp
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
x, pos, batch = data
|
||||||
|
if x is not None: x = torch.cat([x, pos], dim=1)
|
||||||
|
x1 = self.mlp(x)
|
||||||
|
|
||||||
|
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
||||||
|
|
||||||
|
batch_size = x1.shape[0]
|
||||||
|
pos1 = x1.new_zeros((batch_size, 3)) # set the output point as origin
|
||||||
|
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
||||||
|
|
||||||
|
return x1, pos1, batch1
|
||||||
|
|
||||||
|
|
||||||
|
class PointConvFP(MessagePassing):
|
||||||
|
'''
|
||||||
|
Core layer of Feature propagtaion module.
|
||||||
|
'''
|
||||||
|
def __init__(self, mlp=None):
|
||||||
|
super(PointConvFP, self).__init__('add', 'source_to_target')
|
||||||
|
self.mlp = mlp
|
||||||
|
self.aggr = 'add'
|
||||||
|
self.flow = 'source_to_target'
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
reset(self.mlp)
|
||||||
|
|
||||||
|
def forward(self, x, pos, edge_index):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
x (tuple), (tensor, tensor) or (tensor, NoneType)
|
||||||
|
pos (tuple): The node position matrix. Either given as
|
||||||
|
tensor for use in general message passing or as tuple for use
|
||||||
|
in message passing in bipartite graphs.
|
||||||
|
edge_index (LongTensor): The edge indices.
|
||||||
|
"""
|
||||||
|
# Do not pass (tensor, None) directly into propagate(), sice it will check each item's size() inside.
|
||||||
|
x_tmp = x[0] if x[1] is None else x
|
||||||
|
aggr_out = self.propagate(edge_index, x=x_tmp, pos=pos)
|
||||||
|
|
||||||
|
#
|
||||||
|
i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
|
||||||
|
x_target, pos_target = x[i], pos[i]
|
||||||
|
|
||||||
|
add = [pos_target,] if x_target is None else [x_target, pos_target]
|
||||||
|
aggr_out = torch.cat([aggr_out, *add], dim=1)
|
||||||
|
|
||||||
|
if self.mlp is not None: aggr_out = self.mlp(aggr_out)
|
||||||
|
|
||||||
|
return aggr_out
|
||||||
|
|
||||||
|
def message(self, x_j, pos_j, pos_i, edge_index):
|
||||||
|
'''
|
||||||
|
x_j: (E, in_channels)
|
||||||
|
pos_j: (E, 3)
|
||||||
|
pos_i: (E, 3)
|
||||||
|
'''
|
||||||
|
dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
|
||||||
|
dist = torch.max(dist, torch.Tensor([1e-10]).to(dist.device, dist.dtype))
|
||||||
|
weight = 1.0 / dist # (E,)
|
||||||
|
|
||||||
|
row, col = edge_index
|
||||||
|
index = col
|
||||||
|
num_nodes = maybe_num_nodes(index, None)
|
||||||
|
wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16 # (E,)
|
||||||
|
weight /= wsum
|
||||||
|
|
||||||
|
return weight.view(-1, 1) * x_j
|
||||||
|
|
||||||
|
def update(self, aggr_out):
|
||||||
|
return aggr_out
|
||||||
|
|
||||||
|
|
||||||
|
class PointNet2FPModule(torch.nn.Module):
|
||||||
|
def __init__(self, knn_num, mlp):
|
||||||
|
super(PointNet2FPModule, self).__init__()
|
||||||
|
self.knn_num = knn_num
|
||||||
|
self.point_conv = PointConvFP(mlp)
|
||||||
|
|
||||||
|
def forward(self, in_layer_data, skip_layer_data):
|
||||||
|
in_x, in_pos, in_batch = in_layer_data
|
||||||
|
skip_x, skip_pos, skip_batch = skip_layer_data
|
||||||
|
|
||||||
|
row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
|
||||||
|
edge_index = torch.stack([col, row], dim=0)
|
||||||
|
|
||||||
|
x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
|
||||||
|
pos1, batch1 = skip_pos, skip_batch
|
||||||
|
|
||||||
|
return x1, pos1, batch1
|
||||||
|
|
||||||
|
|
||||||
|
def make_mlp(in_channels, mlp_channels, batch_norm=True):
|
||||||
|
assert len(mlp_channels) >= 1
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for c in mlp_channels:
|
||||||
|
layers += [Lin(in_channels, c)]
|
||||||
|
if batch_norm: layers += [BatchNorm1d(c)]
|
||||||
|
layers += [ReLU()]
|
||||||
|
|
||||||
|
in_channels = c
|
||||||
|
|
||||||
|
return Seq(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class PointNet2PartSegmentNet(torch.nn.Module):
|
||||||
|
'''
|
||||||
|
ref:
|
||||||
|
- https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
|
||||||
|
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
|
||||||
|
'''
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
super(PointNet2PartSegmentNet, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
|
# SA1
|
||||||
|
sa1_sample_ratio = 0.5
|
||||||
|
sa1_radius = 0.2
|
||||||
|
sa1_max_num_neighbours = 64
|
||||||
|
sa1_mlp = make_mlp(3, [64, 64, 128])
|
||||||
|
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
||||||
|
|
||||||
|
# SA2
|
||||||
|
sa2_sample_ratio = 0.25
|
||||||
|
sa2_radius = 0.4
|
||||||
|
sa2_max_num_neighbours = 64
|
||||||
|
sa2_mlp = make_mlp(128+3, [128, 128, 256])
|
||||||
|
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
||||||
|
|
||||||
|
# SA3
|
||||||
|
sa3_mlp = make_mlp(256+3, [256, 512, 1024])
|
||||||
|
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
||||||
|
|
||||||
|
##
|
||||||
|
knn_num = 3
|
||||||
|
|
||||||
|
# FP3, reverse of sa3
|
||||||
|
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
||||||
|
fp3_mlp = make_mlp(1024+256+3, [256, 256])
|
||||||
|
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
||||||
|
|
||||||
|
# FP2, reverse of sa2
|
||||||
|
fp2_knn_num = knn_num
|
||||||
|
fp2_mlp = make_mlp(256+128+3, [256, 128])
|
||||||
|
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
||||||
|
|
||||||
|
# FP1, reverse of sa1
|
||||||
|
fp1_knn_num = knn_num
|
||||||
|
fp1_mlp = make_mlp(128+3, [128, 128, 128])
|
||||||
|
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
||||||
|
|
||||||
|
self.fc1 = Lin(128, 128)
|
||||||
|
self.dropout1 = Dropout(p=0.5)
|
||||||
|
self.fc2 = Lin(128, self.num_classes)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
'''
|
||||||
|
data: a batch of input, torch.Tensor or torch_geometric.data.Data type
|
||||||
|
- torch.Tensor: (batch_size, 3, num_points), as common batch input
|
||||||
|
|
||||||
|
- torch_geometric.data.Data, as torch_geometric batch input:
|
||||||
|
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
|
||||||
|
~num_points means each sample can have different number of points/nodes
|
||||||
|
|
||||||
|
data.pos: (batch_size * ~num_points, 3)
|
||||||
|
|
||||||
|
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
|
||||||
|
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||||
|
pytorch_gemometric documentation for more information
|
||||||
|
'''
|
||||||
|
dense_input = True if isinstance(data, torch.Tensor) else False
|
||||||
|
|
||||||
|
if dense_input:
|
||||||
|
# Convert to torch_geometric.data.Data type
|
||||||
|
data = data.transpose(1, 2).contiguous()
|
||||||
|
batch_size, N, _ = data.shape # (batch_size, num_points, 3)
|
||||||
|
pos = data.view(batch_size*N, -1)
|
||||||
|
batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
|
||||||
|
for i in range(batch_size): batch[i] = i
|
||||||
|
batch = batch.view(-1)
|
||||||
|
|
||||||
|
data = Data()
|
||||||
|
data.pos, data.batch = pos, batch
|
||||||
|
|
||||||
|
if not hasattr(data, 'x'): data.x = None
|
||||||
|
data_in = data.x, data.pos, data.batch
|
||||||
|
|
||||||
|
sa1_out = self.sa1_module(data_in)
|
||||||
|
sa2_out = self.sa2_module(sa1_out)
|
||||||
|
sa3_out = self.sa3_module(sa2_out)
|
||||||
|
|
||||||
|
fp3_out = self.fp3_module(sa3_out, sa2_out)
|
||||||
|
fp2_out = self.fp2_module(fp3_out, sa1_out)
|
||||||
|
fp1_out = self.fp1_module(fp2_out, data_in)
|
||||||
|
|
||||||
|
fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out
|
||||||
|
x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
|
||||||
|
x = F.log_softmax(x, dim=-1)
|
||||||
|
|
||||||
|
if dense_input: return x.view(batch_size, N, self.num_classes)
|
||||||
|
else: return x, fp1_out_batch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
num_classes = 10
|
||||||
|
net = PointNet2PartSegmentNet(num_classes)
|
||||||
|
|
||||||
|
#
|
||||||
|
print('Test dense input ..')
|
||||||
|
data1 = torch.rand((2, 3, 1024)) # (batch_size, 3, num_points)
|
||||||
|
print('data1: ', data1.shape)
|
||||||
|
|
||||||
|
out1 = net(data1)
|
||||||
|
print('out1: ', out1.shape)
|
||||||
|
|
||||||
|
#
|
||||||
|
print('Test torch_geometric.data.Data input ..')
|
||||||
|
def make_data_batch():
|
||||||
|
# batch_size = 2
|
||||||
|
pos_num1 = 1000
|
||||||
|
pos_num2 = 1024
|
||||||
|
|
||||||
|
data_batch = Data()
|
||||||
|
|
||||||
|
# data_batch.x = None
|
||||||
|
data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
|
||||||
|
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
||||||
|
|
||||||
|
return data_batch
|
||||||
|
|
||||||
|
data2 = make_data_batch()
|
||||||
|
# print('data.x: ', data.x)
|
||||||
|
print('data2.pos: ', data2.pos.shape)
|
||||||
|
print('data2.batch: ', data2.batch.shape)
|
||||||
|
|
||||||
|
out2_x, out2_batch = net(data2)
|
||||||
|
print('out2_x: ', out2_x.shape)
|
||||||
|
print('out2_batch: ', out2_batch.shape)
|
133
vis/show_seg_res.py
Normal file
133
vis/show_seg_res.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# Warning: import open3d may lead crash, try to to import open3d first here
|
||||||
|
from view import view_points_labels
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
|
||||||
|
|
||||||
|
from dataset.shapenet import ShapeNetPartSegDataset
|
||||||
|
from model.pointnet2_part_seg import PointNet2PartSegmentNet
|
||||||
|
import torch_geometric.transforms as GT
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
##
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--dataset', type=str, default='data', help='dataset path')
|
||||||
|
parser.add_argument('--npoints', type=int, default=50, help='resample points number')
|
||||||
|
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_Airplane_24.pth', help='model path')
|
||||||
|
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
print(opt)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
print('Construct dataset ..')
|
||||||
|
test_transform = GT.Compose([GT.NormalizeScale(),])
|
||||||
|
|
||||||
|
test_dataset = ShapeNetPartSegDataset(
|
||||||
|
root_dir=opt.dataset,
|
||||||
|
train=False,
|
||||||
|
transform=test_transform,
|
||||||
|
npoints=opt.npoints
|
||||||
|
)
|
||||||
|
num_classes = test_dataset.num_classes()
|
||||||
|
|
||||||
|
print('test dataset size: ', len(test_dataset))
|
||||||
|
print('num_classes: ', num_classes)
|
||||||
|
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
print('Construct model ..')
|
||||||
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
dtype = torch.float
|
||||||
|
|
||||||
|
# net = PointNetPartSegmentNet(num_classes)
|
||||||
|
net = PointNet2PartSegmentNet(num_classes)
|
||||||
|
|
||||||
|
net.load_state_dict(torch.load(opt.model))
|
||||||
|
net = net.to(device, dtype)
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
|
||||||
|
##
|
||||||
|
def eval_sample(net, sample):
|
||||||
|
'''
|
||||||
|
sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
|
||||||
|
return: (pred_label, gt_label) with labels shape (n,)
|
||||||
|
'''
|
||||||
|
net.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
# points: (n, 3)
|
||||||
|
points, gt_label = sample['points'], sample['labels']
|
||||||
|
n = points.shape[0]
|
||||||
|
|
||||||
|
points = points.view(1, n, 3) # make a batch
|
||||||
|
points = points.transpose(1, 2).contiguous()
|
||||||
|
points = points.to(device, dtype)
|
||||||
|
|
||||||
|
pred = net(points) # (batch_size, n, num_classes)
|
||||||
|
pred_label = pred.max(2)[1]
|
||||||
|
pred_label = pred_label.view(-1).cpu() # (n,)
|
||||||
|
|
||||||
|
assert pred_label.shape == gt_label.shape
|
||||||
|
return (pred_label, gt_label)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mIoU(pred_label, gt_label):
|
||||||
|
minl, maxl = np.min(gt_label), np.max(gt_label)
|
||||||
|
ious = []
|
||||||
|
for l in range(minl, maxl+1):
|
||||||
|
I = np.sum(np.logical_and(pred_label == l, gt_label == l))
|
||||||
|
U = np.sum(np.logical_or(pred_label == l, gt_label == l))
|
||||||
|
if U == 0: iou = 1
|
||||||
|
else: iou = float(I) / U
|
||||||
|
ious.append(iou)
|
||||||
|
return np.mean(ious)
|
||||||
|
|
||||||
|
|
||||||
|
def label_diff(pred_label, gt_label):
|
||||||
|
'''
|
||||||
|
Assign 1 if different label, or 0 if same label
|
||||||
|
'''
|
||||||
|
diff = pred_label - gt_label
|
||||||
|
diff_mask = (diff != 0)
|
||||||
|
|
||||||
|
diff_label = np.zeros((pred_label.shape[0]), dtype=np.int32)
|
||||||
|
diff_label[diff_mask] = 1
|
||||||
|
|
||||||
|
return diff_label
|
||||||
|
|
||||||
|
|
||||||
|
# Get one sample and eval
|
||||||
|
sample = test_dataset[opt.sample_idx]
|
||||||
|
|
||||||
|
print('Eval test sample ..')
|
||||||
|
pred_label, gt_label = eval_sample(net, sample)
|
||||||
|
print('Eval done ..')
|
||||||
|
|
||||||
|
|
||||||
|
# Get sample result
|
||||||
|
print('Compute mIoU ..')
|
||||||
|
points = sample['points'].numpy()
|
||||||
|
pred_labels = pred_label.numpy()
|
||||||
|
gt_labels = gt_label.numpy()
|
||||||
|
diff_labels = label_diff(pred_labels, gt_labels)
|
||||||
|
|
||||||
|
print('mIoU: ', compute_mIoU(pred_labels, gt_labels))
|
||||||
|
|
||||||
|
|
||||||
|
# View result
|
||||||
|
|
||||||
|
# print('View gt labels ..')
|
||||||
|
# view_points_labels(points, gt_labels)
|
||||||
|
|
||||||
|
# print('View diff labels ..')
|
||||||
|
# view_points_labels(points, diff_labels)
|
||||||
|
|
||||||
|
print('View pred labels ..')
|
||||||
|
view_points_labels(points, pred_labels)
|
63
vis/view.py
Normal file
63
vis/view.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import open3d as o3d
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def mini_color_table(index, norm=True):
|
||||||
|
colors = [
|
||||||
|
[0.5000, 0.5400, 0.5300], [0.8900, 0.1500, 0.2100], [0.6400, 0.5800, 0.5000],
|
||||||
|
[1.0000, 0.3800, 0.0100], [1.0000, 0.6600, 0.1400], [0.4980, 1.0000, 0.0000],
|
||||||
|
[0.4980, 1.0000, 0.8314], [0.9412, 0.9725, 1.0000], [0.5412, 0.1686, 0.8863],
|
||||||
|
[0.5765, 0.4392, 0.8588], [0.3600, 0.1400, 0.4300], [0.5600, 0.3700, 0.6000],
|
||||||
|
]
|
||||||
|
|
||||||
|
assert index >= 0 and index < len(colors)
|
||||||
|
color = colors[index]
|
||||||
|
|
||||||
|
if not norm:
|
||||||
|
color[0] *= 255
|
||||||
|
color[1] *= 255
|
||||||
|
color[2] *= 255
|
||||||
|
|
||||||
|
return color
|
||||||
|
|
||||||
|
|
||||||
|
def view_points(points, colors=None):
|
||||||
|
'''
|
||||||
|
points: np.ndarray with shape (n, 3)
|
||||||
|
colors: [r, g, b] or np.array with shape (n, 3)
|
||||||
|
'''
|
||||||
|
cloud = o3d.PointCloud()
|
||||||
|
cloud.points = o3d.Vector3dVector(points)
|
||||||
|
|
||||||
|
if colors is not None:
|
||||||
|
if isinstance(colors, np.ndarray):
|
||||||
|
cloud.colors = o3d.Vector3dVector(colors)
|
||||||
|
else: cloud.paint_uniform_color(colors)
|
||||||
|
|
||||||
|
o3d.draw_geometries([cloud])
|
||||||
|
|
||||||
|
|
||||||
|
def label2color(labels):
|
||||||
|
'''
|
||||||
|
labels: np.ndarray with shape (n, )
|
||||||
|
colors(return): np.ndarray with shape (n, 3)
|
||||||
|
'''
|
||||||
|
num = labels.shape[0]
|
||||||
|
colors = np.zeros((num, 3))
|
||||||
|
|
||||||
|
minl, maxl = np.min(labels), np.max(labels)
|
||||||
|
for l in range(minl, maxl + 1):
|
||||||
|
colors[labels==l, :] = mini_color_table(l)
|
||||||
|
|
||||||
|
return colors
|
||||||
|
|
||||||
|
|
||||||
|
def view_points_labels(points, labels):
|
||||||
|
'''
|
||||||
|
Assign points with colors by labels and view colored points.
|
||||||
|
points: np.ndarray with shape (n, 3)
|
||||||
|
labels: np.ndarray with shape (n, 1), dtype=np.int32
|
||||||
|
'''
|
||||||
|
assert points.shape[0] == labels.shape[0]
|
||||||
|
colors = label2color(labels)
|
||||||
|
view_points(points, colors)
|
Loading…
x
Reference in New Issue
Block a user