Added normals to prediction DataObject
This commit is contained in:
@@ -141,13 +141,16 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
y_all = [-1] * points.shape[0]
|
y_all = [-1] * points.shape[0]
|
||||||
|
|
||||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
|
attr_dict = dict(y=y, pos=points[:, :3])
|
||||||
|
if self.mode == 'predict':
|
||||||
|
attr_dict.update(normals=points[:, 3:6])
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
|
data = Data(**attr_dict)
|
||||||
else:
|
else:
|
||||||
if not data:
|
if not data:
|
||||||
data = defaultdict(list)
|
data = defaultdict(list)
|
||||||
# points=points, norm=points[:, 3:]
|
# points=points, norm=points[:, 3:]
|
||||||
for key, val in dict(y=y, pos=points[:, :3]).items():
|
for key, val in attr_dict.items():
|
||||||
data[key].append(val)
|
data[key].append(val)
|
||||||
|
|
||||||
data = self._transform_and_filter(data)
|
data = self._transform_and_filter(data)
|
||||||
@@ -175,9 +178,10 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
Resample raw point cloud to fixed number of points.
|
Resample raw point cloud to fixed number of points.
|
||||||
Map raw label from range [1, N] to [0, N-1].
|
Map raw label from range [1, N] to [0, N-1].
|
||||||
"""
|
"""
|
||||||
def __init__(self, root_dir, npoints=1024, **kwargs):
|
def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
|
||||||
super(ShapeNetPartSegDataset, self).__init__()
|
super(ShapeNetPartSegDataset, self).__init__()
|
||||||
kwargs.update(dict(root_dir=root_dir))
|
self.mode = mode
|
||||||
|
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
|
||||||
self.npoints = npoints
|
self.npoints = npoints
|
||||||
self.dataset = CustomShapeNet(**kwargs)
|
self.dataset = CustomShapeNet(**kwargs)
|
||||||
|
|
||||||
@@ -199,6 +203,8 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
'points': points, # torch.Tensor (n, 3)
|
'points': points, # torch.Tensor (n, 3)
|
||||||
'labels': labels # torch.Tensor (n,)
|
'labels': labels # torch.Tensor (n,)
|
||||||
}
|
}
|
||||||
|
if self.mode == 'predict':
|
||||||
|
sample.update(normals=data.normals)
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user