Added normals to prediction DataObject

This commit is contained in:
Si11ium 2019-08-09 13:06:29 +02:00
parent 39e5d72226
commit a501dcd6b0
2 changed files with 8 additions and 8 deletions

View File

@ -144,7 +144,7 @@ class CustomShapeNet(InMemoryDataset):
y = torch.as_tensor(y_all, dtype=torch.int)
# This is where you define the keys
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6])
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
if self.collate_per_element:
data = Data(**attr_dict)
else:
@ -193,20 +193,20 @@ class ShapeNetPartSegDataset(Dataset):
# Resample to fixed number of points
try:
npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0]
choice = np.random.choice(data.pos.shape[0], npoints, replace=False)
choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True)
except ValueError:
choice = []
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
pos, labels = data.pos[choice, :], data.y[choice]
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
# pos, labels = data.pos[choice, :], data.y[choice]
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
sample = {
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6)
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
'labels': labels, # torch.Tensor (n,)
# 'pos': pos, # torch.Tensor (n, 3)
# 'normals': normals # torch.Tensor (n, 3)
'pos': pos, # torch.Tensor (n, 3)
'normals': normals # torch.Tensor (n, 3)
}
return sample

View File

@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max
GLOBAL_POINT_FEATURES = 3
GLOBAL_POINT_FEATURES = 6
class PointNet2SAModule(torch.nn.Module):
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):