Added normals to prediction DataObject

This commit is contained in:
Si11ium
2019-08-07 08:54:07 +02:00
parent efba70f19a
commit 74de208831
4 changed files with 20 additions and 20 deletions

View File

@ -135,6 +135,7 @@ class CustomShapeNet(InMemoryDataset):
else:
# Get the y - Label
if self.mode != 'predict':
# TODO: This is shady function, elaborate on it
y_raw = next(i for i, v in enumerate(self.categories.keys()) if v.lower() in element.lower())
y_all = [y_raw] * points.shape[0]
else:
@ -187,15 +188,14 @@ class ShapeNetPartSegDataset(Dataset):
def __getitem__(self, index):
data = self.dataset[index]
points, labels = data.pos, data.y # , data.points, data.norm
# Resample to fixed number of points
try:
choice = np.random.choice(points.shape[0], self.npoints, replace=True)
choice = np.random.choice(data.pos.shape[0], self.npoints, replace=True)
except ValueError:
choice = []
points, labels = points[choice, :], labels[choice]
points, labels = data.pos[choice, :], data.y[choice]
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
@ -204,7 +204,8 @@ class ShapeNetPartSegDataset(Dataset):
'labels': labels # torch.Tensor (n,)
}
if self.mode == 'predict':
sample.update(normals=data.normals)
normals = data.normals[choice]
sample.update(normals=normals)
return sample