Final *hopefully* adjustments

This commit is contained in:
Si11ium
2019-08-01 21:24:31 +02:00
parent ff117ea2f2
commit 22ea950d85
12 changed files with 2191 additions and 53 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

79
predict/predict.py Normal file
View File

@@ -0,0 +1,79 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
from dataset.shapenet import PredictNetPartSegDataset, ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import torch
import numpy as np
import argparse
##
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='data', help='dataset path')
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_249.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
opt = parser.parse_args()
print(opt)
if __name__ == '__main__':
# Load dataset
print('Construct dataset ..')
test_transform = GT.Compose([GT.NormalizeScale(),])
test_dataset = PredictNetPartSegDataset(
root_dir=opt.dataset,
num_classes=4,
transform=None,
npoints=opt.npoints
)
num_classes = test_dataset.num_classes()
print('test dataset size: ', len(test_dataset))
# Load model
print('Construct model ..')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float
# net = PointNetPartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes)
net.load_state_dict(torch.load(opt.model, map_location=device.type))
net = net.to(device, dtype)
net.eval()
##
def eval_sample(net, sample):
'''
sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
return: (pred_label, gt_label) with labels shape (n,)
'''
net.eval()
with torch.no_grad():
# points: (n, 3)
points, gt_label = sample['points'], sample['labels']
n = points.shape[0]
points = points.view(1, n, 3) # make a batch
points = points.transpose(1, 2).contiguous()
points = points.to(device, dtype)
pred = net(points) # (batch_size, n, num_classes)
pred_label = pred.max(2)[1]
pred_label = pred_label.view(-1).cpu() # (n,)
assert pred_label.shape == gt_label.shape
return (pred_label, gt_label)
# Iterate over all the samples
for sample in test_dataset:
print('Eval test sample ..')
pred_label, gt_label = eval_sample(net, sample)
print('Eval done ..')
pred_labels = pred_label.numpy()
print(pred_labels)