Added normals to prediction DataObject

This commit is contained in:
Si11ium 2019-08-09 13:06:29 +02:00
parent 39e5d72226
commit a501dcd6b0
2 changed files with 8 additions and 8 deletions

View File

@ -144,7 +144,7 @@ class CustomShapeNet(InMemoryDataset):
y = torch.as_tensor(y_all, dtype=torch.int) y = torch.as_tensor(y_all, dtype=torch.int)
# This is where you define the keys # This is where you define the keys
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6]) attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
if self.collate_per_element: if self.collate_per_element:
data = Data(**attr_dict) data = Data(**attr_dict)
else: else:
@ -193,20 +193,20 @@ class ShapeNetPartSegDataset(Dataset):
# Resample to fixed number of points # Resample to fixed number of points
try: try:
npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0] npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0]
choice = np.random.choice(data.pos.shape[0], npoints, replace=False) choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True)
except ValueError: except ValueError:
choice = [] choice = []
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice] pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
pos, labels = data.pos[choice, :], data.y[choice] # pos, labels = data.pos[choice, :], data.y[choice]
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1] labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
sample = { sample = {
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6) 'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
'labels': labels, # torch.Tensor (n,) 'labels': labels, # torch.Tensor (n,)
# 'pos': pos, # torch.Tensor (n, 3) 'pos': pos, # torch.Tensor (n, 3)
# 'normals': normals # torch.Tensor (n, 3) 'normals': normals # torch.Tensor (n, 3)
} }
return sample return sample

View File

@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max from torch_scatter import scatter_add, scatter_max
GLOBAL_POINT_FEATURES = 3 GLOBAL_POINT_FEATURES = 6
class PointNet2SAModule(torch.nn.Module): class PointNet2SAModule(torch.nn.Module):
def __init__(self, sample_radio, radius, max_num_neighbors, mlp): def __init__(self, sample_radio, radius, max_num_neighbors, mlp):