Added normals to prediction DataObject

This commit is contained in:
Si11ium
2019-08-06 16:18:54 +02:00
parent c3948397dc
commit efba70f19a

@ -141,13 +141,16 @@ class CustomShapeNet(InMemoryDataset):
y_all = [-1] * points.shape[0]
y = torch.as_tensor(y_all, dtype=torch.int)
attr_dict = dict(y=y, pos=points[:, :3])
if self.mode == 'predict':
attr_dict.update(normals=points[:, 3:6])
if self.collate_per_element:
data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
data = Data(**attr_dict)
else:
if not data:
data = defaultdict(list)
# points=points, norm=points[:, 3:]
for key, val in dict(y=y, pos=points[:, :3]).items():
for key, val in attr_dict.items():
data[key].append(val)
data = self._transform_and_filter(data)
@ -175,9 +178,10 @@ class ShapeNetPartSegDataset(Dataset):
Resample raw point cloud to fixed number of points.
Map raw label from range [1, N] to [0, N-1].
"""
def __init__(self, root_dir, npoints=1024, **kwargs):
def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
super(ShapeNetPartSegDataset, self).__init__()
kwargs.update(dict(root_dir=root_dir))
self.mode = mode
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
self.npoints = npoints
self.dataset = CustomShapeNet(**kwargs)
@ -199,6 +203,8 @@ class ShapeNetPartSegDataset(Dataset):
'points': points, # torch.Tensor (n, 3)
'labels': labels # torch.Tensor (n,)
}
if self.mode == 'predict':
sample.update(normals=data.normals)
return sample