Final *hopefully* adjustments

This commit is contained in:
Si11ium 2019-08-01 21:24:31 +02:00
parent ff117ea2f2
commit 22ea950d85
12 changed files with 2191 additions and 53 deletions

6
.idea/other.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
</component>
</project>

View File

@ -10,6 +10,12 @@ import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from torch.utils.data import Dataset
import re
def save_names(name_list, path):
with open(path, 'wb'):
pass
class CustomShapeNet(InMemoryDataset):
@ -181,10 +187,8 @@ class ShapeNetPartSegDataset(Dataset):
class PredictionShapeNet(InMemoryDataset):
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
def __init__(self, root, transform=None, pre_filter=None, pre_transform=None,
headers=True, **kwargs):
def __init__(self, root, transform=None, pre_filter=None, pre_transform=None, headers=True):
self.has_headers = headers
super(PredictionShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
path = self.processed_paths[0]
@ -226,57 +230,59 @@ class PredictionShapeNet(InMemoryDataset):
def process(self, delimiter=' '):
datasets = defaultdict(list)
for idx, setting in enumerate(self.raw_file_names):
path_to_clouds = os.path.join(self.raw_dir, setting)
datasets, filenames = defaultdict(list), []
path_to_clouds = os.path.join(self.raw_dir, self.raw_file_names[0])
if '.headers' in os.listdir(path_to_clouds):
self.has_headers = True
elif 'no.headers' in os.listdir(path_to_clouds):
self.has_headers = False
else:
pass
if '.headers' in os.listdir(path_to_clouds):
self.has_headers = True
elif 'no.headers' in os.listdir(path_to_clouds):
self.has_headers = False
else:
pass
for pointcloud in tqdm(os.scandir(path_to_clouds)):
if not os.path.isdir(pointcloud):
for pointcloud in tqdm(os.scandir(path_to_clouds)):
if not os.path.isdir(pointcloud):
continue
full_cloud_pattern = '\d+?_pc\.(xyz|dat)'
pattern = re.compile(full_cloud_pattern)
for file in os.scandir(pointcloud.path):
if not pattern.match(file.name):
continue
for extention in ['dat', 'xyz']:
file = os.path.join(pointcloud.path, f'pc.{extention}')
if not os.path.exists(file):
continue
with open(file, 'r') as f:
if self.has_headers:
headers = f.__next__()
# Check if there are no useable nodes in this file, header says 0.
if not int(headers.rstrip().split(delimiter)[0]):
continue
with open(file, 'r') as f:
if self.has_headers:
headers = f.__next__()
# Check if there are no useable nodes in this file, header says 0.
if not int(headers.rstrip().split(delimiter)[0]):
continue
# Iterate over all rows
src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
for x in line.rstrip().split(delimiter)[None:None]] for line in f if line != '']
points = torch.tensor(src, dtype=None).squeeze()
if not len(points.shape) > 1:
continue
# pos = points[:, :3]
# norm = points[:, 3:]
y_fake_all = [-1] * points.shape[0]
y = torch.as_tensor(y_fake_all, dtype=torch.int)
# points = torch.as_tensor(points, dtype=torch.float)
# norm = torch.as_tensor(norm, dtype=torch.float)
data = Data(y=y, pos=points[:, :3])
# , points=points, norm=points[:3], )
# ToDo: ANy filter to apply? Then do it here.
if self.pre_filter is not None and not self.pre_filter(data):
data = self.pre_filter(data)
raise NotImplementedError
# ToDo: ANy transformation to apply? Then do it here.
if self.pre_transform is not None:
data = self.pre_transform(data)
raise NotImplementedError
datasets[setting].append(data)
# Iterate over all rows
src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
for x in line.rstrip().split(delimiter)[None:None]] for line in f if line != '']
points = torch.tensor(src, dtype=None).squeeze()
if not len(points.shape) > 1:
continue
# pos = points[:, :3]
# norm = points[:, 3:]
y_fake_all = [-1] * points.shape[0]
y = torch.as_tensor(y_fake_all, dtype=torch.int)
# points = torch.as_tensor(points, dtype=torch.float)
# norm = torch.as_tensor(norm, dtype=torch.float)
data = Data(y=y, pos=points[:, :3])
# , points=points, norm=points[:3], )
# ToDo: ANy filter to apply? Then do it here.
if self.pre_filter is not None and not self.pre_filter(data):
data = self.pre_filter(data)
raise NotImplementedError
# ToDo: ANy transformation to apply? Then do it here.
if self.pre_transform is not None:
data = self.pre_transform(data)
raise NotImplementedError
datasets[self.raw_file_names[0]].append(data)
filenames.append(file)
os.makedirs(self.processed_dir, exist_ok=True)
torch.save(self.collate(datasets[setting]), self.processed_paths[idx])
torch.save(self.collate(datasets[self.raw_file_names[0]]), self.processed_paths[0])
# save_names(filenames)
def __repr__(self):
return f'{self.__class__.__name__}({len(self)})'
@ -287,11 +293,11 @@ class PredictNetPartSegDataset(Dataset):
Resample raw point cloud to fixed number of points.
Map raw label from range [1, N] to [0, N-1].
"""
def __init__(self, root_dir, train=False, transform=None, npoints=2048, headers=True, collate_per_segment=False):
def __init__(self, root_dir, num_classes, transform=None, npoints=2048, headers=True):
super(PredictNetPartSegDataset, self).__init__()
self.npoints = npoints
self.dataset = PredictionShapeNet(root=root_dir, train=train, transform=transform,
headers=headers, collate_per_segment=collate_per_segment)
self._num_classes = num_classes
self.dataset = PredictionShapeNet(root=root_dir, transform=transform, headers=headers)
def __getitem__(self, index):
data = self.dataset[index]
@ -311,11 +317,10 @@ class PredictNetPartSegDataset(Dataset):
'points': points, # torch.Tensor (n, 3)
'labels': labels # torch.Tensor (n,)
}
return sample
def __len__(self):
return len(self.dataset)
def num_classes(self):
return self.dataset.num_classes
return self._num_classes

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

79
predict/predict.py Normal file
View File

@ -0,0 +1,79 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
from dataset.shapenet import PredictNetPartSegDataset, ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import torch
import numpy as np
import argparse
##
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='data', help='dataset path')
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_249.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
opt = parser.parse_args()
print(opt)
if __name__ == '__main__':
# Load dataset
print('Construct dataset ..')
test_transform = GT.Compose([GT.NormalizeScale(),])
test_dataset = PredictNetPartSegDataset(
root_dir=opt.dataset,
num_classes=4,
transform=None,
npoints=opt.npoints
)
num_classes = test_dataset.num_classes()
print('test dataset size: ', len(test_dataset))
# Load model
print('Construct model ..')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float
# net = PointNetPartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes)
net.load_state_dict(torch.load(opt.model, map_location=device.type))
net = net.to(device, dtype)
net.eval()
##
def eval_sample(net, sample):
'''
sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
return: (pred_label, gt_label) with labels shape (n,)
'''
net.eval()
with torch.no_grad():
# points: (n, 3)
points, gt_label = sample['points'], sample['labels']
n = points.shape[0]
points = points.view(1, n, 3) # make a batch
points = points.transpose(1, 2).contiguous()
points = points.to(device, dtype)
pred = net(points) # (batch_size, n, num_classes)
pred_label = pred.max(2)[1]
pred_label = pred_label.view(-1).cpu() # (n,)
assert pred_label.shape == gt_label.shape
return (pred_label, gt_label)
# Iterate over all the samples
for sample in test_dataset:
print('Eval test sample ..')
pred_label, gt_label = eval_sample(net, sample)
print('Eval done ..')
pred_labels = pred_label.numpy()
print(pred_labels)