Added normals to prediction DataObject
This commit is contained in:
@ -141,13 +141,16 @@ class CustomShapeNet(InMemoryDataset):
|
||||
y_all = [-1] * points.shape[0]
|
||||
|
||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||
attr_dict = dict(y=y, pos=points[:, :3])
|
||||
if self.mode == 'predict':
|
||||
attr_dict.update(normals=points[:, 3:6])
|
||||
if self.collate_per_element:
|
||||
data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
|
||||
data = Data(**attr_dict)
|
||||
else:
|
||||
if not data:
|
||||
data = defaultdict(list)
|
||||
# points=points, norm=points[:, 3:]
|
||||
for key, val in dict(y=y, pos=points[:, :3]).items():
|
||||
for key, val in attr_dict.items():
|
||||
data[key].append(val)
|
||||
|
||||
data = self._transform_and_filter(data)
|
||||
@ -175,9 +178,10 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
Resample raw point cloud to fixed number of points.
|
||||
Map raw label from range [1, N] to [0, N-1].
|
||||
"""
|
||||
def __init__(self, root_dir, npoints=1024, **kwargs):
|
||||
def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
|
||||
super(ShapeNetPartSegDataset, self).__init__()
|
||||
kwargs.update(dict(root_dir=root_dir))
|
||||
self.mode = mode
|
||||
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
|
||||
self.npoints = npoints
|
||||
self.dataset = CustomShapeNet(**kwargs)
|
||||
|
||||
@ -199,6 +203,8 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
'points': points, # torch.Tensor (n, 3)
|
||||
'labels': labels # torch.Tensor (n,)
|
||||
}
|
||||
if self.mode == 'predict':
|
||||
sample.update(normals=data.normals)
|
||||
|
||||
return sample
|
||||
|
||||
|
Reference in New Issue
Block a user