Final *hopefully* adjustments

This commit is contained in:
Si11ium 2019-08-01 21:24:31 +02:00
parent ff117ea2f2
commit 22ea950d85
12 changed files with 2191 additions and 53 deletions

6
.idea/other.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
</component>
</project>

View File

@ -10,6 +10,12 @@ import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from torch.utils.data import Dataset
import re
def save_names(name_list, path):
    """Persist *name_list* to *path*, one name per line.

    Args:
        name_list: iterable of names (each is converted with ``str``).
        path: destination file path; an existing file is overwritten.
    """
    # The previous stub opened the file in 'wb' and wrote nothing,
    # which silently truncated the target file while discarding the
    # names it was asked to save.
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{name}\n' for name in name_list)
class CustomShapeNet(InMemoryDataset):
@ -181,10 +187,8 @@ class ShapeNetPartSegDataset(Dataset):
class PredictionShapeNet(InMemoryDataset):
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
def __init__(self, root, transform=None, pre_filter=None, pre_transform=None,
headers=True, **kwargs):
def __init__(self, root, transform=None, pre_filter=None, pre_transform=None, headers=True):
self.has_headers = headers
super(PredictionShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
path = self.processed_paths[0]
@ -226,9 +230,8 @@ class PredictionShapeNet(InMemoryDataset):
def process(self, delimiter=' '):
datasets = defaultdict(list)
for idx, setting in enumerate(self.raw_file_names):
path_to_clouds = os.path.join(self.raw_dir, setting)
datasets, filenames = defaultdict(list), []
path_to_clouds = os.path.join(self.raw_dir, self.raw_file_names[0])
if '.headers' in os.listdir(path_to_clouds):
self.has_headers = True
@ -240,9 +243,10 @@ class PredictionShapeNet(InMemoryDataset):
for pointcloud in tqdm(os.scandir(path_to_clouds)):
if not os.path.isdir(pointcloud):
continue
for extention in ['dat', 'xyz']:
file = os.path.join(pointcloud.path, f'pc.{extention}')
if not os.path.exists(file):
full_cloud_pattern = '\d+?_pc\.(xyz|dat)'
pattern = re.compile(full_cloud_pattern)
for file in os.scandir(pointcloud.path):
if not pattern.match(file.name):
continue
with open(file, 'r') as f:
if self.has_headers:
@ -273,10 +277,12 @@ class PredictionShapeNet(InMemoryDataset):
if self.pre_transform is not None:
data = self.pre_transform(data)
raise NotImplementedError
datasets[setting].append(data)
datasets[self.raw_file_names[0]].append(data)
filenames.append(file)
os.makedirs(self.processed_dir, exist_ok=True)
torch.save(self.collate(datasets[setting]), self.processed_paths[idx])
torch.save(self.collate(datasets[self.raw_file_names[0]]), self.processed_paths[0])
# save_names(filenames)
def __repr__(self):
return f'{self.__class__.__name__}({len(self)})'
@ -287,11 +293,11 @@ class PredictNetPartSegDataset(Dataset):
Resample raw point cloud to fixed number of points.
Map raw label from range [1, N] to [0, N-1].
"""
def __init__(self, root_dir, train=False, transform=None, npoints=2048, headers=True, collate_per_segment=False):
def __init__(self, root_dir, num_classes, transform=None, npoints=2048, headers=True):
super(PredictNetPartSegDataset, self).__init__()
self.npoints = npoints
self.dataset = PredictionShapeNet(root=root_dir, train=train, transform=transform,
headers=headers, collate_per_segment=collate_per_segment)
self._num_classes = num_classes
self.dataset = PredictionShapeNet(root=root_dir, transform=transform, headers=headers)
def __getitem__(self, index):
data = self.dataset[index]
@ -311,11 +317,10 @@ class PredictNetPartSegDataset(Dataset):
'points': points, # torch.Tensor (n, 3)
'labels': labels # torch.Tensor (n,)
}
return sample
def __len__(self):
return len(self.dataset)
def num_classes(self):
return self.dataset.num_classes
return self._num_classes

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

79
predict/predict.py Normal file
View File

@ -0,0 +1,79 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
from dataset.shapenet import PredictNetPartSegDataset, ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import torch
import numpy as np
import argparse
##
# Command-line interface for the prediction script.
# Defaults point at the repo-local data directory and a trained
# checkpoint; all of them can be overridden at invocation time.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='data', help='dataset path')
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_249.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
# NOTE(review): parsing happens at import time (outside the __main__
# guard below), so importing this module requires valid CLI args.
opt = parser.parse_args()
print(opt)
if __name__ == '__main__':
    # Load dataset
    print('Construct dataset ..')
    # NOTE(review): test_transform is built but never used — the dataset
    # below is constructed with transform=None; confirm which is intended.
    test_transform = GT.Compose([GT.NormalizeScale(),])

    # Prediction dataset over the raw point clouds under --dataset.
    # num_classes is hard-coded to 4, matching the four shape categories
    # used elsewhere in this project (Box/Cone/Cylinder/Sphere) —
    # TODO confirm against the dataset definition.
    test_dataset = PredictNetPartSegDataset(
        root_dir=opt.dataset,
        num_classes=4,
        transform=None,
        npoints=opt.npoints
    )

    num_classes = test_dataset.num_classes()
    print('test dataset size: ', len(test_dataset))

    # Load model
    print('Construct model ..')
    # Prefer GPU when available; dtype is float32 throughout.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float

    # net = PointNetPartSegmentNet(num_classes)
    net = PointNet2PartSegmentNet(num_classes)

    # map_location=device.type lets a CUDA-trained checkpoint load on CPU.
    net.load_state_dict(torch.load(opt.model, map_location=device.type))
    net = net.to(device, dtype)
    net.eval()

    ##
    def eval_sample(net, sample):
        '''
        Run one forward pass on a single sample (closure over device/dtype).

        sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
        return: (pred_label, gt_label) with labels shape (n,)
        '''
        net.eval()
        with torch.no_grad():
            # points: (n, 3)
            points, gt_label = sample['points'], sample['labels']
            n = points.shape[0]

            # Add a batch dimension and move to channels-first layout,
            # i.e. (1, 3, n), before handing the cloud to the network.
            points = points.view(1, n, 3)  # make a batch
            points = points.transpose(1, 2).contiguous()
            points = points.to(device, dtype)

            pred = net(points)  # (batch_size, n, num_classes)
            # argmax over the class dimension gives one label per point.
            pred_label = pred.max(2)[1]
            pred_label = pred_label.view(-1).cpu()  # (n,)

            assert pred_label.shape == gt_label.shape
            return (pred_label, gt_label)

    # Iterate over all the samples
    # NOTE(review): --sample_idx is parsed but unused; the loop evaluates
    # every sample in the dataset and only prints the predicted labels.
    for sample in test_dataset:
        print('Eval test sample ..')
        pred_label, gt_label = eval_sample(net, sample)
        print('Eval done ..')

        pred_labels = pred_label.numpy()
        print(pred_labels)